|
| 1 | +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| 2 | +#include <ATen/native/AmpKernels.h> |
| 3 | +#include <ATen/Dispatch.h> |
| 4 | +#include <ATen/core/Tensor.h> |
| 5 | + |
| 6 | +#ifndef AT_PER_OPERATOR_HEADERS |
| 7 | +#include <ATen/Functions.h> |
| 8 | +#include <ATen/NativeFunctions.h> |
| 9 | +#else |
| 10 | +#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h> |
| 11 | +#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_native.h> |
| 12 | +#include <ATen/ops/_amp_update_scale.h> |
| 13 | +#include <ATen/ops/_amp_update_scale_native.h> |
| 14 | +#endif |
| 15 | + |
| 16 | +namespace at::native { |
| 17 | + |
| 18 | +void _amp_foreach_non_finite_check_and_unscale_cpu_( |
| 19 | + TensorList scaled_grads, |
| 20 | + at::Tensor& found_inf, |
| 21 | + const at::Tensor& inv_scale) { |
| 22 | + _amp_foreach_non_finite_check_and_unscale_cpu_stub( |
| 23 | + found_inf.device().type(), scaled_grads, found_inf, inv_scale); |
| 24 | +} |
| 25 | + |
| 26 | +at::Tensor& _amp_update_scale_cpu_ ( |
| 27 | + at::Tensor& current_scale, |
| 28 | + at::Tensor& growth_tracker, |
| 29 | + const at::Tensor& found_inf, |
| 30 | + double growth_factor, |
| 31 | + double backoff_factor, |
| 32 | + int64_t growth_interval) { |
| 33 | + return _amp_update_scale_cpu_stub( |
| 34 | + growth_tracker.device().type(), current_scale, growth_tracker, |
| 35 | + found_inf, growth_factor, backoff_factor, growth_interval); |
| 36 | +} |
| 37 | + |
| 38 | +DEFINE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu_stub); |
| 39 | +DEFINE_DISPATCH(_amp_update_scale_cpu_stub); |
| 40 | + |
| 41 | +} // namespace at::native |
0 commit comments