
Commit 57491d2

Browse files
Skylion007 authored and pytorchmergebot committed
Add bfloat16 + fp16 support to fractional_max_pool for CUDA and CPU (pytorch#116950)
Adds bfloat16 support to fractional_max_pool. If an op supports fp32 and fp16, it really should support bf16 for the most part. Most, but not all, ops satisfy this, so I am adding support for the few that do not.

Pull Request resolved: pytorch#116950
Approved by: https://github.com/lezcano
1 parent 7d61fa2 commit 57491d2
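
For context, a minimal Python-level sketch of what the widened dispatch enables: fractional_max_pool2d (and the 3d variant) accepting bfloat16 and float16 inputs on CPU as well as CUDA. The tensor shape and pooling parameters below are illustrative only and not taken from this PR:

import torch
import torch.nn.functional as F

# With this commit, the CPU and CUDA kernels dispatch on bfloat16/float16
# in addition to float32/float64.
x = torch.randn(1, 3, 16, 16, dtype=torch.bfloat16)
out, idx = F.fractional_max_pool2d(
    x, kernel_size=2, output_size=(8, 8), return_indices=True)
print(out.dtype, out.shape)  # torch.bfloat16 torch.Size([1, 3, 8, 8])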

File tree: 5 files changed, +43 -26 lines


aten/src/ATen/native/FractionalMaxPool2d.cpp

+21 -16

@@ -321,21 +321,24 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_out_cpu) (
   int64_t inputH = input.size(heightDim);
   int64_t inputW = input.size(widthDim);

-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
-    "fractional_max_pool2d_out_frame", [&] {
-      auto input_data = input.data_ptr<scalar_t>();
-      auto output_data = output.data_ptr<scalar_t>();
-      auto indices_data = indices.data_ptr<int64_t>();
-      auto randomSamples_data = randomSamples.data_ptr<scalar_t>();
-      fractional_max_pool2d_out_frame<scalar_t>(
-        input_data,
-        output_data,
-        indices_data,
-        randomSamples_data,
-        numBatch, numPlanes,
-        inputW, inputH,
-        outputW, outputH,
-        poolSizeW, poolSizeH);
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    kBFloat16,
+    kHalf,
+    input.scalar_type(),
+    "fractional_max_pool2d_out_frame", [&] {
+      auto input_data = input.data_ptr<scalar_t>();
+      auto output_data = output.data_ptr<scalar_t>();
+      auto indices_data = indices.data_ptr<int64_t>();
+      auto randomSamples_data = randomSamples.data_ptr<scalar_t>();
+      fractional_max_pool2d_out_frame<scalar_t>(
+        input_data,
+        output_data,
+        indices_data,
+        randomSamples_data,
+        numBatch, numPlanes,
+        inputW, inputH,
+        outputW, outputH,
+        poolSizeW, poolSizeH);
     }
   );
 }

@@ -375,7 +378,9 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cpu) (
   auto gradOutput = gradOutput_.contiguous();

   /* backprop */
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    kBFloat16,
+    kHalf,
     input.scalar_type(), "fractional_max_pool2d_backward_out_frame", [&] {
       auto gradInput_data = gradInput.data_ptr<scalar_t>();
       auto gradOutput_data = gradOutput.data_ptr<scalar_t>();

aten/src/ATen/native/FractionalMaxPool3d.cpp

+6 -2

@@ -237,7 +237,9 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
   auto input = input_.contiguous();
   auto randomSamples = randomSamples_.contiguous();

-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    kBFloat16,
+    kHalf,
     input.scalar_type(),
     "fractional_max_pool3d_out_frame",
     [&] {

@@ -371,7 +373,9 @@ void fractional_max_pool3d_backward_out_cpu_template(
   gradInput.zero_();

   /* backprop */
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    kBFloat16,
+    kHalf,
     input.scalar_type(),
     "fractional_max_pool3d_backward_out_frame",
     [&]{

aten/src/ATen/native/cuda/FractionalMaxPool2d.cu

+8 -2

@@ -180,7 +180,10 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_out_cuda) (
     input_.size(0));
   dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
+    input.scalar_type(),
     "fractional_max_pool2d_out_cuda_frame",
     [&] {
       auto devInput = input_.packed_accessor64<scalar_t, 4>();

@@ -252,7 +255,10 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cuda)(
   dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

   auto devIndices = indices_.packed_accessor64<int64_t, 4>();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.scalar_type(),
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
+    gradOutput.scalar_type(),
     "fractional_max_pool2d_backward_out_cuda_frame",
     [&] {
       auto devGradInput = gradInput_.packed_accessor64<scalar_t, 4>();

aten/src/ATen/native/cuda/FractionalMaxPool3d.cu

+6 -2

@@ -226,7 +226,9 @@ void fractional_max_pool3d_backward_out_cuda_template(
     gradInput_.size(0));
   dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     gradOutput.scalar_type(),
     "fractional_max_pool3d_backward_out_frame",
     [&] {

@@ -285,7 +287,9 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cuda) (
     input_.size(0));
   dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "fractional_max_pool3d_out_frame",
     [&]{

torch/testing/_internal/common_methods_invocations.py

+2 -4

@@ -13799,8 +13799,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
               wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs),
           # vmap does not support random operations
           check_batched_forward_grad=False,
-          dtypes=floating_types(),
-          dtypesIfCUDA=floating_types_and(torch.float16),
+          dtypes=floating_types_and(torch.bfloat16, torch.float16),
           test_neg_view=False,
           sample_inputs_func=sample_inputs_fractional_max_pool2d,
           decorators=(

@@ -13820,8 +13819,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
               wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs),
           # vmap does not support random operations
           check_batched_forward_grad=False,
-          dtypes=floating_types(),
-          dtypesIfCUDA=floating_types_and(torch.float16),
+          dtypes=floating_types_and(torch.bfloat16, torch.float16),
           test_neg_view=False,
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           sample_inputs_func=sample_inputs_fractional_max_pool3d,
