Commit 938803d

jerryzh168 authored and pytorchmergebot committed
Add bfloat16 support for per tensor/channel cpu/cuda fake quantize ops (pytorch#139306)
Summary: Fixes https://fb.workplace.com/groups/2240361332735959/permalink/8190736677698365

Test Plan:
buck2 test 'fbcode//mode/dev' fbcode//caffe2/test/quantization:test_quantization -- --exact 'caffe2/test/quantization:test_quantization - test_forward_per_channel_cachemask_cpu (caffe2.test.quantization.core.test_workflow_ops.TestFakeQuantizeOps)'
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- --exact 'caffe2/test/quantization:test_quantization - test_forward_per_tensor_cachemask_cpu (caffe2.test.quantization.core.test_workflow_ops.TestFakeQuantizeOps)'
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- --exact 'caffe2/test/quantization:test_quantization - test_forward_per_channel_cachemask_cuda (caffe2.test.quantization.core.test_workflow_ops.TestFakeQuantizeOps)'
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- --exact 'caffe2/test/quantization:test_quantization - test_forward_per_channel_cachemask_cpu (caffe2.test.quantization.core.test_workflow_ops.TestFakeQuantizeOps)'

Differential Revision: D65221710
Pull Request resolved: pytorch#139306
Approved by: https://github.com/navsud
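Illustration (not part of the commit): the user-facing behavior this change enables, via the public Python ops backed by the kernels below; shapes and quantization parameters here are arbitrary.

    import torch

    # With this change, the fake quantize ops also accept bfloat16 inputs
    # (previously these kernels handled float32/float16/float64).
    x = torch.randn(4, 8, dtype=torch.bfloat16)
    y = torch.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127)  # scale, zero_point, quant_min, quant_max
    assert y.dtype == torch.bfloat16  # output keeps the input dtype

    # The same op runs on CUDA tensors when a GPU is available.
    if torch.cuda.is_available():
        y_cuda = torch.fake_quantize_per_tensor_affine(x.cuda(), 0.1, 0, -128, 127)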
1 parent 53c9c19 commit 938803d

File tree

3 files changed: +129 -56 lines


aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

+46 -19
@@ -2540,25 +2540,46 @@ void _fake_quantize_tensor_helper(
     .add_input(input)
     .build();
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_type_handling", [&] {
-    iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
-      for (const auto i : c10::irange(n)) {
-        scalar_t* output_val = (scalar_t*)(data[0] + i * strides[0]);
-        bool* mask_val = (bool*)(data[1] + i * strides[1]);
-        scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]);
-
-        const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
-        if (fake_quant_on) {
-          *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
-          *mask_val = ((quant_min <= qval) && (qval <= quant_max));
-        } else {
-          *output_val = *input_val;
-          *mask_val = 1;
-        }
-      }
-    });
-  });
+  if (at::isReducedFloatingType(input.scalar_type())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_type_handling", [&]() {
+      iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
+        for (const auto i : c10::irange(n)) {
+          scalar_t* output_val = (scalar_t*)(data[0] + i * strides[0]);
+          bool* mask_val = (bool*)(data[1] + i * strides[1]);
+          scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]);
+
+          const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
+          if (fake_quant_on) {
+            *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
+            *mask_val = ((quant_min <= qval) && (qval <= quant_max));
+          } else {
+            *output_val = *input_val;
+            *mask_val = 1;
+          }
+        }
+      });
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_type_handling", [&] {
+      iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
+        for (const auto i : c10::irange(n)) {
+          scalar_t* output_val = (scalar_t*)(data[0] + i * strides[0]);
+          bool* mask_val = (bool*)(data[1] + i * strides[1]);
+          scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]);
+
+          const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
+          if (fake_quant_on) {
+            *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
+            *mask_val = ((quant_min <= qval) && (qval <= quant_max));
+          } else {
+            *output_val = *input_val;
+            *mask_val = 1;
+          }
+        }
+      });
+    });
+  }
 }
 
 void fake_quantize_tensor_cachemask_kernel(
     Tensor& output,
@@ -2705,9 +2726,15 @@ void fake_quant_per_channel_cachemask_cpu(
   // TODO(future, optional): read once, write twice. Not done at the moment
   // for simplicity, as we do not expect this to be a bottleneck.
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "fake_quantize_channel_cachemask_cpu_type_handling", [&] {
-    _fake_quant_per_channel_cachemask_cpu_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
-  });
+  if (at::isReducedFloatingType(iter.dtype())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "fake_quantize_channel_cachemask_cpu_type_handling", [&]() {
+      _fake_quant_per_channel_cachemask_cpu_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "fake_quantize_channel_cachemask_cpu_type_handling", [&] {
+      _fake_quant_per_channel_cachemask_cpu_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
+    });
+  }
 }
 
aten/src/ATen/native/quantized/cuda/FakeQuantizeCore.cu

+80 -34
@@ -34,20 +34,38 @@ void fake_quantize_tensor_cachemask_kernel_cuda(
     .add_output(mask)
     .add_input(input)
     .build();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
-    gpu_kernel_multiple_outputs(
-      iter,
-      [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
-        const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
-        return {
-          // fake_quantized value
-          (fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
-          // mask for grad
-          ((quant_min <= qval) && (qval <= quant_max))
-        };
-      }
-    );
-  });
+
+  if (at::isReducedFloatingType(input.scalar_type())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
+      gpu_kernel_multiple_outputs(
+        iter,
+        [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
+          const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
+          return {
+            // fake_quantized value
+            (fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
+            // mask for grad
+            ((quant_min <= qval) && (qval <= quant_max))
+          };
+        }
+      );
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
+      gpu_kernel_multiple_outputs(
+        iter,
+        [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
+          const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
+          return {
+            // fake_quantized value
+            (fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
+            // mask for grad
+            ((quant_min <= qval) && (qval <= quant_max))
+          };
+        }
+      );
+    });
+  }
 }
 
 void fake_quantize_tensor_cachemask_tensor_qparams_kernel_cuda(
@@ -68,24 +86,46 @@ void fake_quantize_tensor_cachemask_tensor_qparams_kernel_cuda(
     .add_output(mask)
     .add_input(input)
     .build();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
-    gpu_kernel_multiple_outputs(
-      iter,
-      [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
-        if (*fake_quant_on == 0) {
-          return {input_val, 1};
-        }
-        float inv_scale = 1.0f / (*scale_ptr);
-        const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + (*zp_ptr));
-        return {
-          // fake_quantized value
-          (fminf(quant_max, fmaxf(quant_min, qval)) - (*zp_ptr)) * (*scale_ptr),
-          // mask for grad
-          ((quant_min <= qval) && (qval <= quant_max))
-        };
-      }
-    );
-  });
+
+  if (at::isReducedFloatingType(input.scalar_type())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
+      gpu_kernel_multiple_outputs(
+        iter,
+        [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
+          if (*fake_quant_on == 0) {
+            return {input_val, 1};
+          }
+          float inv_scale = 1.0f / (*scale_ptr);
+          const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + (*zp_ptr));
+          return {
+            // fake_quantized value
+            (fminf(quant_max, fmaxf(quant_min, qval)) - (*zp_ptr)) * (*scale_ptr),
+            // mask for grad
+            ((quant_min <= qval) && (qval <= quant_max))
+          };
+        }
+      );
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
+      gpu_kernel_multiple_outputs(
+        iter,
+        [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
+          if (*fake_quant_on == 0) {
+            return {input_val, 1};
+          }
+          float inv_scale = 1.0f / (*scale_ptr);
+          const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + (*zp_ptr));
+          return {
+            // fake_quantized value
+            (fminf(quant_max, fmaxf(quant_min, qval)) - (*zp_ptr)) * (*scale_ptr),
+            // mask for grad
+            ((quant_min <= qval) && (qval <= quant_max))
+          };
+        }
+      );
+    });
+  }
 }
 
 void _fake_quantize_grad_learnable_tensor_kernel_cuda(
@@ -181,9 +221,15 @@ void _fake_quant_per_channel_cachemask_cuda_helper(
 
 void fake_quant_per_channel_cachemask_cuda(
     TensorIterator &iter, TensorIterator &iter_mask, int64_t quant_min, int64_t quant_max) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "fake_quantize_channel_cachemask_cpu_type_handling", [&] {
-    _fake_quant_per_channel_cachemask_cuda_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
-  });
+  if (at::isReducedFloatingType(iter.dtype())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "fake_quantize_channel_cachemask_cuda_type_handling", [&] {
+      _fake_quant_per_channel_cachemask_cuda_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "fake_quantize_channel_cachemask_cuda_type_handling", [&] {
+      _fake_quant_per_channel_cachemask_cuda_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
+    });
+  }
 }
 
 void _fake_quantize_grad_learnable_channel_kernel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max, float grad_factor) {
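Aside (illustrative sketch, not from the commit): the element-wise math that both the CPU and CUDA kernels above implement, written in plain PyTorch; the helper name is made up.

    import torch

    # Hypothetical helper mirroring the kernels: quantize, clamp, dequantize, and
    # record which values fell inside [quant_min, quant_max] (the mask used for the
    # straight-through gradient).
    def fake_quant_with_mask(x, scale, zero_point, quant_min, quant_max):
        inv_scale = 1.0 / scale
        qval = torch.round(x.float() * inv_scale) + zero_point  # round-to-nearest-even, like std::nearbyint
        y = (torch.clamp(qval, quant_min, quant_max) - zero_point) * scale
        mask = (qval >= quant_min) & (qval <= quant_max)
        return y.to(x.dtype), mask  # output keeps the input dtype, e.g. bfloat16

    y, mask = fake_quant_with_mask(torch.randn(8, dtype=torch.bfloat16), 0.1, 0, -128, 127)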

test/quantization/core/test_workflow_ops.py

+3 -3
@@ -331,7 +331,7 @@ def test_forward_per_tensor_half_precision_numerics(self):
         self.assertEqual(Y3, Y3r, rtol=tolerance, atol=tolerance)
 
     def _test_forward_per_tensor_cachemask_impl(self, device):
-        float_types = (torch.float32, torch.float16, torch.float64)
+        float_types = (torch.float32, torch.float16, torch.float64, torch.bfloat16)
         torch_types = (torch.qint8, torch.quint8)
         Xs = (torch.randn(4, 8, device=device), torch.randn(4, 16, device=device)[:, ::2])
         tensor_qparam = (True, False)
@@ -698,7 +698,7 @@ def test_forward_per_channel(self, device, X):
 
     def _test_forward_per_channel_cachemask_impl(self, device):
         torch_types = (torch.qint8, torch.quint8)
-        float_types = (torch.float32, torch.float16, torch.float64)
+        float_types = (torch.float32, torch.float16, torch.float64, torch.bfloat16)
         zero_point_types = (torch.int, torch.float32, torch.float16)
 
         for torch_type, float_type, zero_point_type in itertools.product(torch_types, float_types, zero_point_types):
@@ -716,7 +716,7 @@ def _test_forward_per_channel_cachemask_impl(self, device):
                 X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
             Y_prime = torch.fake_quantize_per_channel_affine(
                 X, scale, zero_point, axis, quant_min, quant_max)
-            np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
+            torch.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
             self.assertTrue(Y.dtype == float_type)
 
     def test_forward_per_channel_cachemask_cpu(self):
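For completeness, a minimal standalone check in the spirit of the updated test (illustrative only; shapes, scale, and tolerances are arbitrary):

    import torch

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    scale = torch.full((4,), 0.1)
    zero_point = torch.zeros(4, dtype=torch.int32)

    # Per-channel fake quantize now accepts bfloat16 and preserves the dtype.
    y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 0, -128, 127)
    assert y.dtype == torch.bfloat16

    # Compare against the same op run in float32, with a loose tolerance for bf16 rounding.
    y_ref = torch.fake_quantize_per_channel_affine(x.float(), scale, zero_point, 0, -128, 127)
    torch.testing.assert_close(y.float(), y_ref, rtol=1e-2, atol=1e-2)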
