Skip to content

Commit 44df652

Browse files
jiayisunxpytorchmergebot
authored andcommitted
add Half/BFloat16 support for grid_sample on CPU (pytorch#134812)
Fix pytorch#127224. Pull Request resolved: pytorch#134812 Approved by: https://github.com/Skylion007, https://github.com/mingfeima
1 parent d558c1a commit 44df652

File tree

6 files changed

+13
-85
lines changed

6 files changed

+13
-85
lines changed

aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h

+1-67
Original file line numberDiff line numberDiff line change
@@ -221,73 +221,7 @@ static_assert(
221221
}
222222
template <int64_t mask>
223223
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
224-
__at_align__ int16_t tmp_values[size()];
225-
a.store(tmp_values);
226-
if (mask & 0x01)
227-
tmp_values[0] = b.values[31];
228-
if (mask & 0x02)
229-
tmp_values[1] = b.values[30];
230-
if (mask & 0x04)
231-
tmp_values[2] = b.values[29];
232-
if (mask & 0x08)
233-
tmp_values[3] = b.values[28];
234-
if (mask & 0x10)
235-
tmp_values[4] = b.values[27];
236-
if (mask & 0x20)
237-
tmp_values[5] = b.values[26];
238-
if (mask & 0x40)
239-
tmp_values[6] = b.values[25];
240-
if (mask & 0x80)
241-
tmp_values[7] = b.values[24];
242-
if (mask & 0x100)
243-
tmp_values[8] = b.values[23];
244-
if (mask & 0x200)
245-
tmp_values[9] = b.values[22];
246-
if (mask & 0x400)
247-
tmp_values[10] = b.values[21];
248-
if (mask & 0x800)
249-
tmp_values[11] = b.values[20];
250-
if (mask & 0x1000)
251-
tmp_values[12] = b.values[19];
252-
if (mask & 0x2000)
253-
tmp_values[13] = b.values[18];
254-
if (mask & 0x4000)
255-
tmp_values[14] = b.values[17];
256-
if (mask & 0x8000)
257-
tmp_values[15] = b.values[16];
258-
if (mask & 0x10000)
259-
tmp_values[16] = b.values[15];
260-
if (mask & 0x20000)
261-
tmp_values[17] = b.values[14];
262-
if (mask & 0x40000)
263-
tmp_values[18] = b.values[13];
264-
if (mask & 0x80000)
265-
tmp_values[19] = b.values[12];
266-
if (mask & 0x100000)
267-
tmp_values[20] = b.values[11];
268-
if (mask & 0x200000)
269-
tmp_values[21] = b.values[10];
270-
if (mask & 0x400000)
271-
tmp_values[22] = b.values[9];
272-
if (mask & 0x800000)
273-
tmp_values[23] = b.values[8];
274-
if (mask & 0x1000000)
275-
tmp_values[24] = b.values[7];
276-
if (mask & 0x2000000)
277-
tmp_values[25] = b.values[6];
278-
if (mask & 0x4000000)
279-
tmp_values[26] = b.values[5];
280-
if (mask & 0x8000000)
281-
tmp_values[27] = b.values[4];
282-
if (mask & 0x10000000)
283-
tmp_values[28] = b.values[3];
284-
if (mask & 0x20000000)
285-
tmp_values[29] = b.values[2];
286-
if (mask & 0x40000000)
287-
tmp_values[30] = b.values[1];
288-
if (mask & 0x80000000)
289-
tmp_values[31] = b.values[0];
290-
return loadu(tmp_values);
224+
return _mm512_mask_blend_epi16(mask, a.values, b.values);
291225
}
292226
static Vectorized<T> blendv(const Vectorized<T>& a,
293227
const Vectorized<T>& b, const Vectorized<T>& mask) {

aten/src/ATen/native/GridSampler.cpp

+4-8
Original file line numberDiff line numberDiff line change
@@ -930,9 +930,7 @@ Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid,
930930
}
931931
// AVX gather instructions use signed 32-bit offsets to gather float values.
932932
// Check for possible overflow and fallback to scalar implementation
933-
if (input.scalar_type() != kDouble) {
934-
TORCH_CHECK(input.scalar_type() == kFloat,
935-
"grid_sampler_2d_cpu not implemented for ", input.scalar_type());
933+
if (input.scalar_type() == kFloat) {
936934
auto sizes = input.sizes();
937935
auto strides = input.strides();
938936
const auto grid_sW = grid.strides()[2];
@@ -968,7 +966,7 @@ Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid,
968966
check_grid_sampler_common(input, grid);
969967
check_grid_sampler_3d(input, grid, interpolation_mode);
970968

971-
return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler3d_cpu", [&] {
969+
return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "grid_sampler3d_cpu", [&] {
972970
return grid_sampler_3d_cpu_impl<scalar_t>(
973971
input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode),
974972
static_cast<GridSamplerPadding>(padding_mode), align_corners);
@@ -986,9 +984,7 @@ grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, con
986984

987985
// AVX gather instructions use signed 32-bit offsets to gather float values.
988986
// Check for possible overflow and fallback to scalar implementation
989-
if (input.scalar_type() != kDouble) {
990-
TORCH_CHECK(input.scalar_type() == kFloat,
991-
"grid_sampler_2d_backward_cpu not implemented for ", input.scalar_type());
987+
if (input.scalar_type() == kFloat) {
992988
auto isizes = input.sizes();
993989
auto istrides = input.strides();
994990
auto gsizes = grad_output.sizes();
@@ -1033,7 +1029,7 @@ grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, con
10331029
check_grid_sampler_common(input, grid);
10341030
check_grid_sampler_3d(input, grid, interpolation_mode);
10351031

1036-
return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] {
1032+
return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] {
10371033
return grid_sampler_3d_backward_cpu_impl<scalar_t>(
10381034
grad_output, input, grid,
10391035
static_cast<GridSamplerInterpolation>(interpolation_mode),

aten/src/ATen/native/cpu/GridSamplerKernel.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,7 @@ void grid_sampler_2d_cpu_kernel_impl(
11841184
return; \
11851185
}
11861186

1187-
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_cpu_kernel_impl", [&] {
1187+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "grid_sampler_2d_cpu_kernel_impl", [&] {
11881188
auto out_acc = output.accessor<scalar_t, 4>();
11891189
auto inp_acc = input.accessor<const scalar_t, 4>();
11901190
auto grid_acc = grid.accessor<const scalar_t, 4>();
@@ -1272,7 +1272,7 @@ void grid_sampler_2d_backward_cpu_kernel_impl(
12721272
return; \
12731273
}
12741274

1275-
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_backward_cpu_kernel_impl", [&] {
1275+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "grid_sampler_2d_backward_cpu_kernel_impl", [&] {
12761276
auto gGrid_acc = grad_grid.accessor<scalar_t, 4>();
12771277
auto inp_acc = input.accessor<const scalar_t, 4>();
12781278
auto grid_acc = grid.accessor<const scalar_t, 4>();

test/inductor/test_torchinductor_opinfo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs):
693693
"nn.functional.cosine_similarity": {f16},
694694
"nn.functional.cross_entropy": {f16, f32, f64},
695695
"nn.functional.gaussian_nll_loss": {f16},
696-
"nn.functional.grid_sample": {f32, f64},
696+
"nn.functional.grid_sample": {f32, f64, f16},
697697
"nn.functional.interpolate.area": {f16},
698698
"nn.functional.nll_loss": {f16, f32, f64},
699699
"normal": {f16, f32, f64},

test/test_mps.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def mps_ops_grad_modifier(ops):
152152

153153
MACOS_12_3_XFAILLIST_GRAD = {
154154
# Unsupported Border padding mode, forward pass success as fallback to cpu
155-
'grid_sampler_2d': [torch.float32],
155+
'grid_sampler_2d': [torch.float32, torch.float16, torch.bfloat16],
156156
# Unimplemented
157157
'logaddexp2': [torch.float32],
158158

@@ -165,7 +165,7 @@ def mps_ops_grad_modifier(ops):
165165
'masked.log_softmax': [torch.float32, torch.float16],
166166

167167
# Unsupported Border padding mode, forward pass success as fallback to cpu
168-
'grid_sampler_2d': [torch.float32],
168+
'grid_sampler_2d': [torch.float32, torch.float16, torch.bfloat16],
169169

170170
# Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
171171
# Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
@@ -638,7 +638,7 @@ def mps_ops_modifier(ops):
638638

639639
MACOS_AFTER_13_1_XFAILLIST = {
640640
# before macOS 13.2 it falls back to cpu and pass the forward pass
641-
'grid_sampler_2d': [torch.float32], # Unsupported Border padding mode
641+
'grid_sampler_2d': [torch.float32, torch.float16, torch.bfloat16], # Unsupported Border padding mode
642642
# inconsistency errors between cpu and mps, max seen atol is 2
643643
'nn.functional.interpolatebilinear': [torch.uint8],
644644
}

torch/testing/_internal/common_methods_invocations.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -20811,8 +20811,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
2081120811
),
2081220812
OpInfo(
2081320813
"nn.functional.grid_sample",
20814-
dtypes=floating_types(),
20815-
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
20814+
dtypes=floating_types_and(torch.float16, torch.bfloat16),
2081620815
supports_out=False,
2081720816
sample_inputs_func=sample_inputs_grid_sample,
2081820817
reference_inputs_func=reference_inputs_grid_sample,
@@ -20821,8 +20820,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
2082120820
# TODO: delete this OpInfo once we add meta support for grid_sampler_3d
2082220821
OpInfo(
2082320822
"grid_sampler_2d",
20824-
dtypes=floating_types(),
20825-
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
20823+
dtypes=floating_types_and(torch.float16, torch.bfloat16),
2082620824
supports_out=False,
2082720825
sample_inputs_func=sample_inputs_grid_sampler_2d,
2082820826
supports_gradgrad=False,

0 commit comments

Comments
 (0)