@@ -34,20 +34,38 @@ void fake_quantize_tensor_cachemask_kernel_cuda(
     .add_output(mask)
     .add_input(input)
     .build();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
-    gpu_kernel_multiple_outputs(
-      iter,
-      [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
-        const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
-        return {
-          // fake_quantized value
-          (fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
-          // mask for grad
-          ((quant_min <= qval) && (qval <= quant_max))
-        };
-      }
-    );
-  });
+
+  if (at::isReducedFloatingType(input.scalar_type())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
+      gpu_kernel_multiple_outputs(
+        iter,
+        [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
+          const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
+          return {
+            // fake_quantized value
+            (fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
+            // mask for grad
+            ((quant_min <= qval) && (qval <= quant_max))
+          };
+        }
+      );
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
+      gpu_kernel_multiple_outputs(
+        iter,
+        [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
+          const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
+          return {
+            // fake_quantized value
+            (fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
+            // mask for grad
+            ((quant_min <= qval) && (qval <= quant_max))
+          };
+        }
+      );
+    });
+  }
 }
 
 void fake_quantize_tensor_cachemask_tensor_qparams_kernel_cuda(
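For reference, a minimal host-side sketch of the dispatch split this hunk introduces: at::isReducedFloatingType picks the branch, Half and BFloat16 inputs go through AT_DISPATCH_REDUCED_FLOATING_TYPES, and everything else keeps the original AT_DISPATCH_FLOATING_TYPES_AND_HALF path. The function name fake_quant_cpu_reference and its plain CPU loop are illustrative assumptions, not code from this file.

// Host-side sketch of the two-way dispatch used by the CUDA kernels above.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <algorithm>
#include <cmath>
#include <type_traits>

at::Tensor fake_quant_cpu_reference(
    const at::Tensor& input, float scale, int64_t zero_point,
    int64_t quant_min, int64_t quant_max) {
  auto in = input.contiguous();
  auto out = at::empty_like(in);
  const float inv_scale = 1.0f / scale;
  // Same per-element math as the GPU lambda: round, clamp, rescale.
  auto run = [&](auto* in_ptr, auto* out_ptr) {
    using T = std::remove_pointer_t<decltype(out_ptr)>;
    for (int64_t i = 0; i < in.numel(); ++i) {
      const auto qval = static_cast<int64_t>(
          std::nearbyint(static_cast<float>(in_ptr[i]) * inv_scale) + zero_point);
      const auto clamped = std::min(quant_max, std::max(quant_min, qval));
      out_ptr[i] = static_cast<T>((clamped - zero_point) * scale);
    }
  };
  // Same branch structure as the kernels above: reduced floating types
  // (Half, BFloat16) dispatch through AT_DISPATCH_REDUCED_FLOATING_TYPES,
  // everything else keeps the original AT_DISPATCH_FLOATING_TYPES_AND_HALF.
  if (at::isReducedFloatingType(in.scalar_type())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(in.scalar_type(), "fake_quant_cpu_reference", [&] {
      run(in.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(in.scalar_type(), "fake_quant_cpu_reference", [&] {
      run(in.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    });
  }
  return out;
}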
@@ -68,24 +86,46 @@ void fake_quantize_tensor_cachemask_tensor_qparams_kernel_cuda(
     .add_output(mask)
     .add_input(input)
     .build();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
-    gpu_kernel_multiple_outputs(
-      iter,
-      [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
-        if (*fake_quant_on == 0) {
-          return {input_val, 1};
+
+  if (at::isReducedFloatingType(input.scalar_type())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
+      gpu_kernel_multiple_outputs(
+        iter,
+        [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
+          if (*fake_quant_on == 0) {
+            return {input_val, 1};
+          }
+          float inv_scale = 1.0f / (*scale_ptr);
+          const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + (*zp_ptr));
+          return {
+            // fake_quantized value
+            (fminf(quant_max, fmaxf(quant_min, qval)) - (*zp_ptr)) * (*scale_ptr),
+            // mask for grad
+            ((quant_min <= qval) && (qval <= quant_max))
+          };
         }
-        float inv_scale = 1.0f / (*scale_ptr);
-        const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + (*zp_ptr));
-        return {
-          // fake_quantized value
-          (fminf(quant_max, fmaxf(quant_min, qval)) - (*zp_ptr)) * (*scale_ptr),
-          // mask for grad
-          ((quant_min <= qval) && (qval <= quant_max))
-        };
-      }
-    );
-  });
+      );
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_types", [&] {
+      gpu_kernel_multiple_outputs(
+        iter,
+        [=] GPU_LAMBDA (scalar_t input_val) -> thrust::tuple<scalar_t, bool> {
+          if (*fake_quant_on == 0) {
+            return {input_val, 1};
+          }
+          float inv_scale = 1.0f / (*scale_ptr);
+          const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + (*zp_ptr));
+          return {
+            // fake_quantized value
+            (fminf(quant_max, fmaxf(quant_min, qval)) - (*zp_ptr)) * (*scale_ptr),
+            // mask for grad
+            ((quant_min <= qval) && (qval <= quant_max))
+          };
+        }
+      );
+    });
+  }
 }
 
 void _fake_quantize_grad_learnable_tensor_kernel_cuda(
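For reference, a standalone worked example of the per-element arithmetic inside the lambdas above; the values scale = 0.1, zero_point = 0 and the range [0, 255] are arbitrary illustration, not defaults from this file. The mask records whether the rounded value stayed inside [quant_min, quant_max], which the cachemask backward pass is expected to use to gate the gradient.

// Pure standard C++ walk-through of the round/clamp/mask computation.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 0.1f;
  const int64_t zero_point = 0, quant_min = 0, quant_max = 255;
  const float inv_scale = 1.0f / scale;
  for (float x : {3.0f, 30.0f}) {  // 30.0 lands outside the representable range
    const auto qval = static_cast<int64_t>(std::nearbyint(x * inv_scale) + zero_point);
    const float fq = (std::min(quant_max, std::max(quant_min, qval)) - zero_point) * scale;
    const bool mask = (quant_min <= qval) && (qval <= quant_max);
    std::printf("x=%.1f -> qval=%lld, fake_quant=%.1f, mask=%d\n",
                x, static_cast<long long>(qval), fq, mask);
  }
  // Prints: x=3.0 -> qval=30, fake_quant=3.0, mask=1
  //         x=30.0 -> qval=300, fake_quant=25.5, mask=0
  return 0;
}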
@@ -181,9 +221,15 @@ void _fake_quant_per_channel_cachemask_cuda_helper(
 
 void fake_quant_per_channel_cachemask_cuda(
     TensorIterator &iter, TensorIterator &iter_mask, int64_t quant_min, int64_t quant_max) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "fake_quantize_channel_cachemask_cpu_type_handling", [&] {
-    _fake_quant_per_channel_cachemask_cuda_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
-  });
+  if (at::isReducedFloatingType(iter.dtype())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "fake_quantize_channel_cachemask_cuda_type_handling", [&] {
+      _fake_quant_per_channel_cachemask_cuda_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "fake_quantize_channel_cachemask_cuda_type_handling", [&] {
+      _fake_quant_per_channel_cachemask_cuda_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
+    });
+  }
 }
 
 void _fake_quantize_grad_learnable_channel_kernel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max, float grad_factor) {
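A hedged usage sketch of what the dispatch changes enable: calling the public fake quantize ops on a BFloat16 CUDA tensor, which previously had no dispatch case for reduced floating types in these kernels. That these ops reach this file's cachemask kernels in a given build is an assumption of the sketch, as are the shapes and quantization parameters.

// Usage sketch; assumes a CUDA-enabled libtorch build.
#include <ATen/ATen.h>

int main() {
  if (!at::hasCUDA()) {
    return 0;  // nothing to demonstrate on a CPU-only build
  }
  auto opts = at::device(at::kCUDA).dtype(at::kBFloat16);
  auto x = at::randn({4, 8}, opts);

  // Per-tensor fake quantization: scale = 0.1, zero_point = 0, range [0, 255].
  auto y = at::fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255);

  // Per-channel fake quantization over dim 0: one (scale, zero_point) per row.
  auto scales = at::full({4}, 0.1, at::device(at::kCUDA).dtype(at::kFloat));
  auto zero_points = at::zeros({4}, at::device(at::kCUDA).dtype(at::kInt));
  auto z = at::fake_quantize_per_channel_affine(x, scales, zero_points,
                                                /*axis=*/0, 0, 255);
  return (y.defined() && z.defined()) ? 0 : 1;
}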