
Commit 29516bd

CaoE authored and pytorchmergebot committed
add _amp_foreach_non_finite_check_and_unscale_cpu_ and _amp_update_scale_cpu_ kernels on CPU (pytorch#109281)
Step 1 of pytorch#111559.
Pull Request resolved: pytorch#109281
Approved by: https://github.com/jgong5, https://github.com/ezyang
1 parent 0fa6ee4 commit 29516bd

17 files changed: +457 −22 lines
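
The two new native functions route AMP's gradient unscaling and loss-scale update through CPU dispatch stubs. Below is a minimal sketch of how they could be exercised end to end on CPU tensors; the at::_amp_* binding names are assumed from the op names in this commit and are not shown in this excerpt.

// Hedged sketch, not part of this diff: drive the new CPU AMP ops directly.
#include <ATen/ATen.h>

int main() {
  auto grad = at::randn({8}) * 128.0;            // a gradient scaled by 128
  auto found_inf = at::zeros({1});               // set to 1.0 if any inf/nan is found
  auto inv_scale = at::full({1}, 1.0f / 128.0f); // reciprocal of the loss scale

  // Unscale in place and record whether any element was non-finite.
  at::_amp_foreach_non_finite_check_and_unscale_({grad}, found_inf, inv_scale);

  // Update the loss scale: back off on inf/nan, grow after enough clean steps.
  auto scale = at::full({1}, 128.0f);
  auto growth_tracker = at::zeros({1}, at::kInt);
  at::_amp_update_scale_(scale, growth_tracker, found_inf,
                         /*growth_factor=*/2.0, /*backoff_factor=*/0.5,
                         /*growth_interval=*/2000);
  return 0;
}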

aten/src/ATen/cpu/vec/vec256/vec256_double.h (+4)

@@ -100,6 +100,10 @@ template <> class Vectorized<double> {
   Vectorized<double> isnan() const {
     return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
   }
+  bool has_inf_nan() const {
+    __m256d self_sub = _mm256_sub_pd(values, values);
+    return (_mm256_movemask_epi8(_mm256_castpd_si256(self_sub)) & 0x77777777) != 0;
+  }
   Vectorized<double> map(double (*const f)(double)) const {
     __at_align__ double tmp[size()];
     store(tmp);
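
The AVX2 (and AVX-512) has_inf_nan variants rely on the identity that x - x is ±0 for every finite x but NaN for ±inf and NaN inputs, so any set payload bit in values - values flags a non-finite lane; the 0x77777777 mask drops each lane's sign bit so a negative-zero difference is never counted. A scalar illustration of the same idea, as a hedged sketch that is not part of this diff:

#include <cmath>
#include <cstdio>

// Scalar version of the trick used by the SIMD has_inf_nan() above:
// for finite x, x - x == +/-0; for +/-inf or NaN, x - x == NaN.
static bool scalar_has_inf_nan(double x) {
  double self_sub = x - x;
  return std::isnan(self_sub);
}

int main() {
  std::printf("%d %d %d %d\n",
              scalar_has_inf_nan(1.5),            // 0
              scalar_has_inf_nan(0.0),            // 0
              scalar_has_inf_nan(INFINITY),       // 1
              scalar_has_inf_nan(std::nan("")));  // 1
}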

aten/src/ATen/cpu/vec/vec256/vec256_float.h (+6)

@@ -106,6 +106,12 @@ template <> class Vectorized<float> {
   Vectorized<float> isnan() const {
     return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
   }
+
+  bool has_inf_nan() const {
+    __m256 self_sub = _mm256_sub_ps(values, values);
+    return (_mm256_movemask_epi8(_mm256_castps_si256(self_sub)) & 0x77777777) != 0;
+  }
+
   Vectorized<float> map(float (*const f)(float)) const {
     __at_align__ float tmp[size()];
     store(tmp);

aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h (+10)

@@ -307,6 +307,16 @@ template <> class Vectorized<float> {
     }
     return loadu(res);
   };
+  bool has_inf_nan() const {
+    __at_align__ float tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      if(_isnan(tmp[i]) || _isinf(tmp[i])) {
+        return true;
+      }
+    }
+    return false;
+  }
   Vectorized<float> map(float (*const f)(float)) const {
     __at_align__ float tmp[size()];
     store(tmp);

aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h (+13)

@@ -383,6 +383,19 @@ class Vectorized<double> {
     auto ret = (x == x);
     return ret._nor();
   }
+  bool has_inf_nan() const {
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
+        return true;
+      }
+    }
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
+        return true;
+      }
+    }
+    return false;
+  }

   DEFINE_MEMBER_OP(operator==, double, vec_cmpeq)
   DEFINE_MEMBER_OP(operator!=, double, vec_cmpne)

aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h (+14)

@@ -239,6 +239,20 @@ class Vectorized<float> {
     return (x == v_inf) | (x == v_minus_inf);
   }

+  bool has_inf_nan() const {
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
+        return true;
+      }
+    }
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   int zero_mask() const {
     // returns an integer mask where all zero elements are translated to 1-bit
     // and others are translated to 0-bit

aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h (+14)

@@ -875,6 +875,20 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
     return ret._not();
   }

+  bool has_inf_nan() const {
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
+        return true;
+      }
+    }
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   template <
       typename U = T,
       std::enable_if_t<std::is_floating_point<U>::value, int> = 0>

aten/src/ATen/cpu/vec/vec512/vec512_double.h (+4)

@@ -106,6 +106,10 @@ template <> class Vectorized<double> {
     return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
                                                       0xFFFFFFFFFFFFFFFF));
   }
+  bool has_inf_nan() const {
+    __m512d self_sub = _mm512_sub_pd(values, values);
+    return (_mm512_movepi8_mask(_mm512_castpd_si512(self_sub)) & 0x7777777777777777) != 0;
+  }
   Vectorized<double> map(double (*const f)(double)) const {
     __at_align__ double tmp[size()];
     store(tmp);

aten/src/ATen/cpu/vec/vec512/vec512_float.h (+4)

@@ -125,6 +125,10 @@ template <> class Vectorized<float> {
     return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
                                                       0xFFFFFFFF));
  }
+  bool has_inf_nan() const {
+    __m512 self_sub = _mm512_sub_ps(values, values);
+    return (_mm512_movepi8_mask(_mm512_castps_si512(self_sub)) & 0x7777777777777777) != 0;
+  }
   Vectorized<float> map(float (*const f)(float)) const {
     __at_align__ float tmp[size()];
     store(tmp);

aten/src/ATen/cpu/vec/vec_base.h (+8)

@@ -255,6 +255,14 @@ struct Vectorized {
     }
     return vector;
   }
+  bool has_inf_nan() const {
+    for (int64_t i = 0; i != size(); i++) {
+      if(_isnan(values[i]) || _isinf(values[i])) {
+        return true;
+      }
+    }
+    return false;
+  }
   Vectorized<T> map(T (*const f)(T)) const {
     Vectorized<T> ret;
     for (int64_t i = 0; i != size(); i++) {

aten/src/ATen/native/AmpKernels.cpp (+41)

@@ -0,0 +1,41 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/native/AmpKernels.h>
+#include <ATen/Dispatch.h>
+#include <ATen/core/Tensor.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h>
+#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_native.h>
+#include <ATen/ops/_amp_update_scale.h>
+#include <ATen/ops/_amp_update_scale_native.h>
+#endif
+
+namespace at::native {
+
+void _amp_foreach_non_finite_check_and_unscale_cpu_(
+    TensorList scaled_grads,
+    at::Tensor& found_inf,
+    const at::Tensor& inv_scale) {
+  _amp_foreach_non_finite_check_and_unscale_cpu_stub(
+      found_inf.device().type(), scaled_grads, found_inf, inv_scale);
+}
+
+at::Tensor& _amp_update_scale_cpu_ (
+    at::Tensor& current_scale,
+    at::Tensor& growth_tracker,
+    const at::Tensor& found_inf,
+    double growth_factor,
+    double backoff_factor,
+    int64_t growth_interval) {
+  return _amp_update_scale_cpu_stub(
+      growth_tracker.device().type(), current_scale, growth_tracker,
+      found_inf, growth_factor, backoff_factor, growth_interval);
+}
+
+DEFINE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu_stub);
+DEFINE_DISPATCH(_amp_update_scale_cpu_stub);
+
+} // namespace at::native
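
The kernels registered for these stubs live in a separate file of this commit (not shown in this excerpt) and are what actually walk the gradients. A minimal sketch of the check-and-unscale inner loop such a kernel performs, using the Vectorized<T>::has_inf_nan() helper added above; this is an illustration with a made-up helper name, not the committed implementation.

#include <ATen/cpu/vec/vec.h>
#include <cmath>
#include <cstdint>

// Hedged sketch: vectorized non-finite check + in-place unscale over one
// contiguous float buffer. Real kernels also handle dtype dispatch, tensor
// iteration, and writing 1.0 into found_inf.
static bool unscale_and_check(float* grad, int64_t n, float inv_scale) {
  using Vec = at::vec::Vectorized<float>;
  bool found_non_finite = false;
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    Vec v = Vec::loadu(grad + i);
    found_non_finite |= v.has_inf_nan();   // helper added by this commit
    (v * Vec(inv_scale)).store(grad + i);  // unscale in place
  }
  for (; i < n; ++i) {                     // scalar tail
    found_non_finite |= std::isnan(grad[i]) || std::isinf(grad[i]);
    grad[i] *= inv_scale;
  }
  return found_non_finite;
}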

aten/src/ATen/native/AmpKernels.h (+28)

@@ -0,0 +1,28 @@
+#pragma once
+
+#include <ATen/native/DispatchStub.h>
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
+    TensorList,
+    Tensor&,
+    const Tensor&);
+
+using _amp_update_scale_cpu__fn = Tensor& (*)(
+    Tensor&,
+    Tensor&,
+    const Tensor&,
+    double,
+    double,
+    int64_t);
+
+DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
+DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
+
+} // namespace native
+} // namespace at
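
For completeness, stubs declared with DECLARE_DISPATCH are bound to concrete kernels via REGISTER_DISPATCH in an implementation translation unit (one of the changed files not shown in this excerpt); a hedged sketch with illustrative kernel names:

// In a native/cpu/ translation unit (kernel function names are illustrative):
REGISTER_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu_stub,
                  &_amp_foreach_non_finite_check_and_unscale_cpu_kernel);
REGISTER_DISPATCH(_amp_update_scale_cpu_stub, &_amp_update_scale_cpu_kernel);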
