@@ -930,9 +930,7 @@ Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid,
930
930
}
931
931
// AVX gather instructions use signed 32-bit offsets to gather float values.
932
932
// Check for possible overflow and fallback to scalar implementation
933
- if (input.scalar_type () != kDouble ) {
934
- TORCH_CHECK (input.scalar_type () == kFloat ,
935
- " grid_sampler_2d_cpu not implemented for " , input.scalar_type ());
933
+ if (input.scalar_type () == kFloat ) {
936
934
auto sizes = input.sizes ();
937
935
auto strides = input.strides ();
938
936
const auto grid_sW = grid.strides ()[2 ];
@@ -968,7 +966,7 @@ Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid,
968
966
check_grid_sampler_common (input, grid);
969
967
check_grid_sampler_3d (input, grid, interpolation_mode);
970
968
971
- return AT_DISPATCH_FLOATING_TYPES ( input.scalar_type (), " grid_sampler3d_cpu" , [&] {
969
+ return AT_DISPATCH_FLOATING_TYPES_AND2 ( kHalf , kBFloat16 , input.scalar_type (), " grid_sampler3d_cpu" , [&] {
972
970
return grid_sampler_3d_cpu_impl<scalar_t >(
973
971
input, grid, static_cast <GridSamplerInterpolation>(interpolation_mode),
974
972
static_cast <GridSamplerPadding>(padding_mode), align_corners);
@@ -986,9 +984,7 @@ grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, con
986
984
987
985
// AVX gather instructions use signed 32-bit offsets to gather float values.
988
986
// Check for possible overflow and fallback to scalar implementation
989
- if (input.scalar_type () != kDouble ) {
990
- TORCH_CHECK (input.scalar_type () == kFloat ,
991
- " grid_sampler_2d_backward_cpu not implemented for " , input.scalar_type ());
987
+ if (input.scalar_type () == kFloat ) {
992
988
auto isizes = input.sizes ();
993
989
auto istrides = input.strides ();
994
990
auto gsizes = grad_output.sizes ();
@@ -1033,7 +1029,7 @@ grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, con
1033
1029
check_grid_sampler_common (input, grid);
1034
1030
check_grid_sampler_3d (input, grid, interpolation_mode);
1035
1031
1036
- return AT_DISPATCH_FLOATING_TYPES ( input.scalar_type (), " grid_sampler_3d_backward_cpu" , [&] {
1032
+ return AT_DISPATCH_FLOATING_TYPES_AND2 ( kHalf , kBFloat16 , input.scalar_type (), " grid_sampler_3d_backward_cpu" , [&] {
1037
1033
return grid_sampler_3d_backward_cpu_impl<scalar_t >(
1038
1034
grad_output, input, grid,
1039
1035
static_cast <GridSamplerInterpolation>(interpolation_mode),
0 commit comments