@@ -24,8 +24,8 @@ static PyObject* THPStream_pynew(
24
24
HANDLE_TH_ERRORS
25
25
26
26
int64_t stream_id = -1 ;
27
- int64_t device_type = 0 ;
28
- int64_t device_index = 0 ;
27
+ c10::DeviceType device_type{} ;
28
+ c10::DeviceIndex device_index{} ;
29
29
int64_t priority = 0 ;
30
30
31
31
static torch::PythonArgParser parser ({
@@ -42,27 +42,25 @@ static PyObject* THPStream_pynew(
42
42
auto default_accelerator = at::getAccelerator (false );
43
43
auto device = r.deviceOptional (0 );
44
44
if (device.has_value ()) {
45
- device_type = static_cast < int64_t >( device->type () );
46
- device_index = static_cast < int64_t >( device->index () );
45
+ device_type = device->type ();
46
+ device_index = device->index ();
47
47
// Initialize device guard if device is not None.
48
48
device_guard_ptr = std::make_unique<c10::DeviceGuard>(device.value ());
49
49
} else {
50
50
// If device is None, we will use the current accelerator and index.
51
51
// If the current accelerator is not set, we will use the CPU as device
52
52
// type.
53
- device_type = static_cast <int64_t >(
54
- default_accelerator.value_or (c10::DeviceType::CPU));
55
- c10::impl::VirtualGuardImpl impl{
56
- static_cast <c10::DeviceType>(device_type)};
53
+ device_type = default_accelerator.value_or (c10::DeviceType::CPU);
54
+ c10::impl::VirtualGuardImpl impl{device_type};
57
55
const auto current_device = impl.getDevice ();
58
56
device_index = current_device.index ();
59
57
}
60
58
priority = r.toInt64WithDefault (1 , 0 );
61
59
} else if (r.idx == 1 ) {
62
60
stream_id = r.toInt64WithDefault (0 , -1 );
63
- device_index = r.toInt64WithDefault (1 , 0 );
64
- device_type =
65
- r.toInt64WithDefault (2 , static_cast <int64_t >(c10::DeviceType::CPU));
61
+ device_index = static_cast <c10::DeviceIndex>( r.toInt64WithDefault (1 , 0 ) );
62
+ device_type = static_cast <c10::DeviceType>(
63
+ r.toInt64WithDefault (2 , static_cast <int64_t >(c10::DeviceType::CPU))) ;
66
64
priority = r.toInt64WithDefault (3 , 0 );
67
65
} else {
68
66
TORCH_CHECK (
@@ -84,19 +82,16 @@ static PyObject* THPStream_pynew(
84
82
// manage the lifetime of streams.
85
83
std::optional<c10::Stream> stream_opt;
86
84
if (r.idx == 0 ) {
87
- c10::impl::VirtualGuardImpl impl{static_cast <c10::DeviceType>( device_type) };
85
+ c10::impl::VirtualGuardImpl impl{device_type};
88
86
stream_opt = impl.getNewStream (
89
- c10::Device (static_cast <c10::DeviceType>(device_type), device_index),
90
- static_cast <int >(priority));
87
+ c10::Device (device_type, device_index), static_cast <int >(priority));
91
88
} else {
92
- stream_opt = c10::Stream::unpack3 (
93
- stream_id,
94
- static_cast <c10::DeviceIndex>(device_index),
95
- static_cast <c10::DeviceType>(device_type));
89
+ stream_opt = c10::Stream::unpack3 (stream_id, device_index, device_type);
96
90
}
97
91
98
92
TORCH_CHECK (stream_opt.has_value (), " Failed to create stream" );
99
93
self->stream_id = static_cast <int64_t >(stream_opt->id ());
94
+ // NOLINTNEXTLINE(bugprone-signed-char-misuse)
100
95
self->device_index = static_cast <int64_t >(stream_opt->device_index ());
101
96
self->device_type = static_cast <int64_t >(stream_opt->device_type ());
102
97
@@ -139,7 +134,7 @@ static PyObject* THPStream_query(PyObject* _self, PyObject* noargs) {
139
134
140
135
return PyBool_FromLong (c10::Stream::unpack3 (
141
136
self->stream_id ,
142
- self->device_index ,
137
+ static_cast <c10::DeviceIndex>( self->device_index ) ,
143
138
static_cast <c10::DeviceType>(self->device_type ))
144
139
.query ());
145
140
@@ -153,7 +148,7 @@ static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) {
153
148
154
149
c10::Stream::unpack3 (
155
150
self->stream_id ,
156
- self->device_index ,
151
+ static_cast <c10::DeviceIndex>( self->device_index ) ,
157
152
static_cast <c10::DeviceType>(self->device_type ))
158
153
.synchronize ();
159
154
}
@@ -167,7 +162,7 @@ static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) {
167
162
auto event = (THPEvent*)_event;
168
163
c10::Stream::unpack3 (
169
164
self->stream_id ,
170
- self->device_index ,
165
+ static_cast <c10::DeviceIndex>( self->device_index ) ,
171
166
static_cast <c10::DeviceType>(self->device_type ))
172
167
.wait (event->event );
173
168
}
@@ -184,11 +179,11 @@ static PyObject* THPStream_wait_stream(PyObject* _self, PyObject* _other) {
184
179
c10::EventFlag::PYTORCH_DEFAULT);
185
180
new_event.record (c10::Stream::unpack3 (
186
181
other_stream->stream_id ,
187
- other_stream->device_index ,
182
+ static_cast <c10::DeviceIndex>( other_stream->device_index ) ,
188
183
static_cast <c10::DeviceType>(other_stream->device_type )));
189
184
c10::Stream::unpack3 (
190
185
self->stream_id ,
191
- self->device_index ,
186
+ static_cast <c10::DeviceIndex>( self->device_index ) ,
192
187
static_cast <c10::DeviceType>(self->device_type ))
193
188
.wait (new_event);
194
189
}
@@ -229,7 +224,7 @@ static PyObject* THPStream_record_event(
229
224
TORCH_CHECK (new_event, " event must not be null" );
230
225
new_event->event .record (c10::Stream::unpack3 (
231
226
self->stream_id ,
232
- self->device_index ,
227
+ static_cast <c10::DeviceIndex>( self->device_index ) ,
233
228
static_cast <c10::DeviceType>(self->device_type )));
234
229
return (PyObject*)new_event;
235
230
END_HANDLE_TH_ERRORS
0 commit comments