
Commit 40c098f

guangyey authored and pytorchmergebot committed on Oct 27, 2024
Introduce a device-agnostic runtime API design (pytorch#132204)
# Motivation

According to [[RFC] A device-agnostic Python runtime API design for stream-based accelerators](pytorch#128403), this PR introduces a device-agnostic runtime API design. I personally prefer the **Simple Version** of the APIs, which no longer accepts the device type as an input argument. Instead, we leverage `getAccelerator` to fetch the current accelerator, and the APIs remain flexible enough to be extended to scenarios with multiple accelerator types. The design does **NOT** break the previous design philosophies.

I also believe the `torch.accelerator` namespace is the better choice: it tells users that the APIs they are calling run on an accelerator rather than on the CPU, which is important. Meanwhile, we can follow a simple API design principle:

1. Device-agnostic APIs should be placed under the `torch.accelerator` namespace and should not accept a `device_type` optional parameter.
2. Device-specific APIs should be placed under device-specific submodules.
3. APIs required by both CPU and accelerators should be placed under the `torch` namespace and accept a `device_type` optional parameter.

Pros and cons of the **Simple Version**:

Pros:
- `torch.accelerator.foo` has the same input arguments as `torch.xxx.foo`, giving a better user experience;
- more concise, making it easier for developers to write device-agnostic code.

Cons:
- no obvious drawbacks.

# Additional Context

The new APIs are:

```python
torch.accelerator.is_available() -> bool:
torch.accelerator.current_accelerator() -> torch.device:
torch.accelerator.device_count() -> int:
torch.accelerator.current_device_idx() -> int:
torch.accelerator.set_device_idx(device: Union[torch.device, str, int, None]) -> None:
torch.accelerator.current_stream(device: Union[torch.device, str, int, None]) -> torch.Stream:
torch.accelerator.set_stream(stream: torch.Stream) -> None:
torch.accelerator.synchronize(device: Union[torch.device, str, int, None]) -> None:
```

Following the discussion with Alban, we decided to rename `set_device` to `set_device_idx` and `current_device` to `current_device_idx` to be more explicit. Device and stream context managers will be added in a follow-up PR.

Pull Request resolved: pytorch#132204
Approved by: https://github.com/EikanWang, https://github.com/abhilash1910, https://github.com/gujinghui, https://github.com/albanD
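As a quick orientation (not part of the original commit message), here is a minimal sketch of how the APIs listed above compose into device-agnostic code; it assumes at least one supported accelerator (e.g. CUDA or XPU) is present and otherwise falls back to CPU:

```python
# Minimal device-agnostic sketch using only the APIs introduced by this PR.
import torch

if torch.accelerator.is_available():
    acc = torch.accelerator.current_accelerator()  # e.g. device(type='cuda'); index is None
    torch.accelerator.set_device_idx(0)            # select device 0 of the current accelerator
    x = torch.randn(1024, device=acc)
    y = (x * 2.0).sum()
    torch.accelerator.synchronize()                # wait for all streams on the current device
    print(acc, torch.accelerator.current_device_idx(), torch.accelerator.device_count())
else:
    print("No accelerator available; running on CPU.")
```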
1 parent 1152726 · commit 40c098f

17 files changed: +343 −0
 

aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h (+9)

```diff
@@ -216,6 +216,15 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
     C10_HIP_CHECK(hipEventSynchronize(hip_event));
   }
 
+  // Note: synchronizeDevice can be safely called from any device
+  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
+    int orig_device{-1};
+    C10_HIP_CHECK(hipGetDevice(&orig_device));
+    C10_HIP_CHECK(hipSetDevice(device_index));
+    C10_HIP_CHECK(hipDeviceSynchronize());
+    C10_HIP_CHECK(hipSetDevice(orig_device));
+  }
+
   void recordDataPtrOnStream(
       const c10::DataPtr& data_ptr,
       const Stream& stream) const override {
```

aten/src/ATen/mps/MPSGuardImpl.h (+2)

```diff
@@ -111,6 +111,8 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
 
   bool queryEvent(void* event) const override;
 
+  void synchronizeDevice(const DeviceIndex device_index) const override;
+
 };
 
 /// A variant of OptionalDeviceGuard that is specialized for MPS.
```

aten/src/ATen/mps/MPSGuardImpl.mm (+4)

```diff
@@ -42,4 +42,8 @@
   return mps_event->query();
 }
 
+void MPSGuardImpl::synchronizeDevice(const DeviceIndex device_index) const {
+  at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
+}
+
 } // namespace at::mps
```

build_variables.bzl (+1)

```diff
@@ -795,6 +795,7 @@ libtorch_python_xpu_sources = [
 
 libtorch_python_core_sources = [
     "torch/csrc/DataLoader.cpp",
+    "torch/csrc/DeviceAccelerator.cpp",
     "torch/csrc/Device.cpp",
     "torch/csrc/Dtype.cpp",
     "torch/csrc/DynamicTypes.cpp",
```

c10/core/impl/DeviceGuardImplInterface.h (+9)

```diff
@@ -212,6 +212,15 @@ struct C10_API DeviceGuardImplInterface {
     TORCH_CHECK(false, "Backend doesn't support synchronizing events.");
   }
 
+  /**
+   * Wait (by blocking the calling thread) until all the work previously
+   * enqueued on the device has been completed.
+   */
+  virtual void synchronizeDevice(const DeviceIndex /*device_index*/) const {
+    TORCH_CHECK(
+        false, "Backend doesn't support synchronizing all streams on device.");
+  }
+
   /**
    * Ensure the caching allocator (if any) is aware that the given DataPtr is
    * being used on the given stream, and that it should thus avoid recycling the
```

c10/core/impl/VirtualGuardImpl.h (+4)

```diff
@@ -96,6 +96,10 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface {
     return impl_->synchronizeEvent(event);
   }
 
+  void synchronizeDevice(const DeviceIndex device_index) const override {
+    return impl_->synchronizeDevice(device_index);
+  }
+
  private:
   const DeviceGuardImplInterface* impl_ = nullptr;
 };
```

c10/cuda/impl/CUDAGuardImpl.h (+13)

```diff
@@ -219,6 +219,19 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     C10_CUDA_CHECK(cudaEventSynchronize(cuda_event));
   }
 
+  // Note: synchronizeDevice can be safely called from any device
+  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
+    DeviceIndex orig_device{-1};
+    C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
+    C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
+    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+    if (C10_UNLIKELY(interp)) {
+      (*interp)->trace_gpu_device_synchronization(c10::kCUDA);
+    }
+    C10_CUDA_CHECK(cudaDeviceSynchronize());
+    C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
+  }
+
   void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
       const override {
     CUDAStream cuda_stream{stream};
```

c10/xpu/impl/XPUGuardImpl.h (+8)

```diff
@@ -163,6 +163,14 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     xpu_event->wait_and_throw();
   }
 
+  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
+    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+    if (C10_UNLIKELY(interp)) {
+      (*interp)->trace_gpu_device_synchronization(c10::kXPU);
+    }
+    c10::xpu::syncStreamsOnDevice(device_index);
+  }
+
   void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
       const override {
     const XPUStream xpu_stream{stream};
```

docs/source/accelerator.rst (new file, +17)

```rst
torch.accelerator
===================================
.. automodule:: torch.accelerator
.. currentmodule:: torch.accelerator

.. autosummary::
    :toctree: generated
    :nosignatures:

    device_count
    is_available
    current_accelerator
    set_device_idx
    current_device_idx
    set_stream
    current_stream
    synchronize
```

docs/source/index.rst (+1)

```diff
@@ -64,6 +64,7 @@ Features described in this documentation are classified by release status:
    torch.amp <amp>
    torch.autograd <autograd>
    torch.library <library>
+   accelerator
    cpu
    cuda
    torch.cuda.memory <torch_cuda_memory>
```

torch/_C/__init__.pyi.in (+9)

```diff
@@ -2183,6 +2183,15 @@ def _set_worker_pids(
 def _remove_worker_pids(loader_id: _int) -> None: ... # THPModule_removeWorkerPIDs
 def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails
 
+# Defined in torch/csrc/DeviceAccelerator.cpp
+def _accelerator_getAccelerator() -> _device: ...
+def _accelerator_deviceCount() -> _int: ...
+def _accelerator_setDeviceIndex(device_index: _int) -> None: ...
+def _accelerator_getDeviceIndex() -> _int: ...
+def _accelerator_setStream(Stream) -> None: ...
+def _accelerator_getStream(device_index: _int) -> Stream: ...
+def _accelerator_synchronizeDevice(device_index: _int) -> None: ...
+
 # Defined in torch/csrc/jit/python/python_tracer.cpp
 class TracingState:
     def push_scope(self, scope_name: str) -> None: ...
```
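For orientation (not part of the diff): these private bindings are the C++ entry points that the new public `torch.accelerator` module, added later in this commit, wraps one-to-one. A short sketch of the correspondence, assuming an accelerator is available:

```python
# Each public wrapper delegates to one of the private bindings declared above
# (the mapping is taken from torch/accelerator/__init__.py in this same commit).
import torch

if torch.accelerator.is_available():
    n = torch.accelerator.device_count()           # torch._C._accelerator_deviceCount()
    acc = torch.accelerator.current_accelerator()  # torch._C._accelerator_getAccelerator()
    idx = torch.accelerator.current_device_idx()   # torch._C._accelerator_getDeviceIndex()
    torch.accelerator.synchronize(idx)             # torch._C._accelerator_synchronizeDevice(idx)
```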

torch/__init__.py (+1)

```diff
@@ -2092,6 +2092,7 @@ def _assert(condition, message):
     __config__ as __config__,
     __future__ as __future__,
     _awaits as _awaits,
+    accelerator as accelerator,
     autograd as autograd,
     backends as backends,
     cpu as cpu,
```

torch/accelerator/__init__.py (new file, +145)

```python
r"""
This package introduces support for the current :ref:`accelerator<accelerators>` in python.
"""

import torch

from ._utils import _device_t, _get_device_index


def device_count() -> int:
    r"""Return the number of current :ref:`accelerator<accelerators>` available.

    Returns:
        int: the number of the current :ref:`accelerator<accelerators>` available.
            If there are no available accelerators, return 0.
    """
    return torch._C._accelerator_deviceCount()


def is_available() -> bool:
    r"""Check if there is an available :ref:`accelerator<accelerators>`.

    Returns:
        bool: A boolean indicating if there is an available :ref:`accelerator<accelerators>`.

    Example::

        >>> assert torch.accelerator.is_available(), "No available accelerators detected."
    """
    return device_count() > 0


def current_accelerator() -> torch.device:
    r"""Return the device of the current :ref:`accelerator<accelerators>`.

    Returns:
        torch.device: return the current accelerator as :class:`torch.device`.

    .. note:: The index of the returned :class:`torch.device` will be ``None``, please use
        :func:`torch.accelerator.current_device_idx` to know the current index being used.
        And ensure to use :func:`torch.accelerator.is_available` to check if there is an available
        accelerator. If there is no available accelerator, this function will raise an exception.

    Example::

        >>> # xdoctest:
        >>> if torch.accelerator.is_available():
        >>>     current_device = torch.accelerator.current_accelerator()
        >>> else:
        >>>     current_device = torch.device("cpu")
        >>> if current_device.type == 'cuda':
        >>>     is_half_supported = torch.cuda.has_half
        >>> elif current_device.type == 'xpu':
        >>>     is_half_supported = torch.xpu.get_device_properties().has_fp16
        >>> elif current_device.type == 'cpu':
        >>>     is_half_supported = True
    """
    return torch._C._accelerator_getAccelerator()


def current_device_idx() -> int:
    r"""Return the index of a currently selected device for the current :ref:`accelerator<accelerators>`.

    Returns:
        int: the index of a currently selected device.
    """
    return torch._C._accelerator_getDeviceIndex()


def set_device_idx(device: _device_t, /) -> None:
    r"""Set the current device index to a given device.

    Args:
        device (:class:`torch.device`, str, int): a given device that must match the current
            :ref:`accelerator<accelerators>` device type.

    .. note:: This function is a no-op if this device index is negative.
    """
    device_index = _get_device_index(device)
    torch._C._accelerator_setDeviceIndex(device_index)


def current_stream(device: _device_t = None, /) -> torch.Stream:
    r"""Return the currently selected stream for a given device.

    Args:
        device (:class:`torch.device`, str, int, optional): a given device that must match the current
            :ref:`accelerator<accelerators>` device type. If not given,
            use :func:`torch.accelerator.current_device_idx` by default.

    Returns:
        torch.Stream: the currently selected stream for a given device.
    """
    device_index = _get_device_index(device, True)
    return torch._C._accelerator_getStream(device_index)


def set_stream(stream: torch.Stream) -> None:
    r"""Set the current stream to a given stream.

    Args:
        stream (torch.Stream): a given stream that must match the current :ref:`accelerator<accelerators>` device type.

    .. note:: This function will set the current device index to the device index of the given stream.
    """
    torch._C._accelerator_setStream(stream)


def synchronize(device: _device_t = None, /) -> None:
    r"""Wait for all kernels in all streams on the given device to complete.

    Args:
        device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match
            the current :ref:`accelerator<accelerators>` device type. If not given,
            use :func:`torch.accelerator.current_device_idx` by default.

    .. note:: This function is a no-op if the current :ref:`accelerator<accelerators>` is not initialized.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> assert torch.accelerator.is_available(), "No available accelerators detected."
        >>> start_event = torch.Event(enable_timing=True)
        >>> end_event = torch.Event(enable_timing=True)
        >>> start_event.record()
        >>> tensor = torch.randn(100, device=torch.accelerator.current_accelerator())
        >>> sum = torch.sum(tensor)
        >>> end_event.record()
        >>> torch.accelerator.synchronize()
        >>> elapsed_time_ms = start_event.elapsed_time(end_event)
    """
    device_index = _get_device_index(device, True)
    torch._C._accelerator_synchronizeDevice(device_index)


__all__ = [
    "current_accelerator",
    "current_device_idx",
    "current_stream",
    "device_count",
    "is_available",
    "set_device_idx",
    "set_stream",
    "synchronize",
]
```
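Beyond the doctests above, a hypothetical stream-handling sketch may help show how `current_stream`/`set_stream` are meant to interact with `synchronize`; it assumes an accelerator is available and that `torch.Stream` can be constructed for that device type (both are assumptions, not guarantees made by this commit):

```python
# Hypothetical sketch: switch to a side stream, do some work, then restore and synchronize.
import torch

if torch.accelerator.is_available():
    acc = torch.accelerator.current_accelerator()
    idx = torch.accelerator.current_device_idx()
    default_stream = torch.accelerator.current_stream(idx)

    side_stream = torch.Stream(device=acc.type)   # assumed constructible for this backend
    torch.accelerator.set_stream(side_stream)     # also switches the current device index

    x = torch.ones(1 << 20, device=acc)
    y = x * 3.0                                   # enqueued on side_stream

    torch.accelerator.set_stream(default_stream)  # restore the previous stream
    torch.accelerator.synchronize(idx)            # block until all streams on device idx finish
```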

torch/accelerator/_utils.py (new file, +28)

```python
from typing import Optional, Union

import torch
from torch import device as _device


_device_t = Union[_device, str, int, None]


def _get_device_index(device: _device_t, optional: bool = False) -> int:
    if isinstance(device, int):
        return device
    if isinstance(device, str):
        device = torch.device(device)
    device_index: Optional[int] = None
    if isinstance(device, torch.device):
        if torch.accelerator.current_accelerator() != device.type:
            raise ValueError(
                f"{device.type} doesn't match the current accelerator {torch.accelerator.current_accelerator()}."
            )
        device_index = device.index
    if device_index is None:
        if not optional:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got:{device}"
            )
        return torch.accelerator.current_device_idx()
    return device_index
```