Commit 5680f56

tugsbayasgalan authored and pytorchmergebot committed
Batch Norm Consolidation (pytorch#116092)
**Summary:** This commit simplifies the existing decomposition hierarchy of batch norm ops by adding a single, backend-agnostic op: `batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the above decomposition hierarchy is cuda numerics in export flows. We observed significantly worse convergence when training a mobilenetv2-like model when using the `_batch_norm_cuda` kernel instead of the `cudnn_batch_norm` kernel. This means users who export their models on CPU first and then move the models to cuda later may silently see worse accuracies even when cudnn is installed, because they are using the worse kernel. This issue is summarized in pytorch#111384.

Instead, the new hierarchy proposed by consolidating existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend implementation details and automatically picks the right kernel based on what is installed. This commit also adds the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants, but does not actually change the decomps to produce these ops in the graph. This will be done after the 2 week FC window, and the ops used in the old stack are planned to be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: pytorch#111384
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: pytorch#116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
1 parent f72eb5a commit 5680f56

35 files changed, +753 -70 lines changed

aten/src/ATen/native/Normalization.cpp

+120 -30
@@ -29,6 +29,11 @@
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
+#include <ATen/ops/_batch_norm_with_update.h>
+#include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/_batch_norm_no_update.h>
+#include <ATen/ops/_batch_norm_no_update_native.h>
+#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/alias.h>
 #include <ATen/ops/batch_norm.h>
 #include <ATen/ops/batch_norm_native.h>
@@ -478,10 +483,58 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
 
+BatchNormBackend _select_batch_norm_backend(
+    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
+    const Tensor& running_var, bool training, double eps) {
+
+  auto& ctx = at::globalContext();
+  bool cudnn_enabled = ctx.userEnabledCuDNN();
+
+  if (
+      input.is_cuda()
+      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
+      && (input.scalar_type() != at::kHalf
+        || weight.scalar_type() == at::kFloat)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && (input.dim() >= 3)
+      && ((input.sym_size(0) <= 880801 && training) // spatial, training
+          ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
+      && detail::getCUDAHooks().compiledWithCuDNN()
+      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
+      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
+  ) {
+    return BatchNormBackend::Cudnn;
+  }
+
+  if (
+      input.is_cuda()
+      && input.dim() <= MIOPEN_DIM_MAX
+      && input.scalar_type() != at::kDouble
+      && input.scalar_type() != at::kBFloat16
+      && (weight.scalar_type() != at::kHalf)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && detail::getCUDAHooks().compiledWithMIOpen()
+      && cudnn_enabled
+      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
+      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
+  ) {
+    return BatchNormBackend::Miopen;
+  }
+
+  return BatchNormBackend::Native;
+}
+
+
 // _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection
 // of backends, while enabling it to keep the information about the used backend, so that it can
 // use its corresponding backward implementation.
 // XXX: The indices of backends need to be kept synchronized between this function and its _backward.
+// TODO: remove cudnn_enabled arg
 std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -526,24 +579,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
   }
 
-  const bool use_cudnn = (
-      input.is_cuda()
-      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
-      && (input.scalar_type() != at::kHalf
-        || weight.scalar_type() == at::kFloat)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && ((input.sym_size(0) <= 880801 && training) // spatial, training
-          ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
-      && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
-      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
-      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
-  );
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
 
-  if (use_cudnn) {
+  if (backend == BatchNormBackend::Cudnn) {
     auto input_c = input.contiguous(input.suggest_memory_format());
     auto weight_c = weight.contiguous();
     auto bias_c = bias.contiguous();
@@ -560,19 +598,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
 
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
 
-  bool use_miopen = (input.is_cuda()
-      && input.dim() <= MIOPEN_DIM_MAX
-      && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
-  );
-
-  if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
+  if (backend == BatchNormBackend::Miopen) {
     return std::tuple_cat(
         at::miopen_batch_norm(
             input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -636,6 +662,7 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
   TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
 }
 
+// TODO: remove cudnn_enabled arg
 Tensor batch_norm(
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
     const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -646,6 +673,30 @@ Tensor batch_norm(
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                 training, momentum, eps, cudnn_enabled));
+  // TODO: switch to the new stack after the 2 week FC window
+  // if (training) {
+  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
+  //     auto input_c = input;
+  //     if (backend == BatchNormBackend::Cudnn) {
+  //       input_c = input.contiguous(input.suggest_memory_format());
+  //     } else {
+  //       input_c = input.contiguous();
+  //     }
+  //     auto weight_c = weight.contiguous();
+  //     auto bias_c = bias.contiguous();
+  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
+  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
+  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
+  //                                                   const_cast<Tensor&>(rvar_c), momentum, eps));
+  //   } else {
+  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
+  //                                                   const_cast<Tensor&>(running_var), momentum, eps));
+  //   }
+  // } else {
+  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
+  //                                               momentum, eps));
+  // }
 }
 
 Tensor instance_norm(
@@ -797,6 +848,38 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
   return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
 }
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+  Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+    batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
+    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
+  std::tie(out, save_mean, save_var) =
+    batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
+  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
+}
+
+
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    double momentum, double eps) {
+  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
+  Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+    batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
 
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
     const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -825,6 +908,13 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
   return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
 }
 
+std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
+    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
+    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
+  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
+}
 
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
                                                            bool train, double eps, std::array<bool,3> grad_input_mask) {
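
The CPU wrappers above differ only in the flag they pass to `batch_norm_cpu`: `/*update*/true` for `_batch_norm_with_update_cpu` and `/*update*/false` for `_batch_norm_no_update`. As a rough illustration of what that flag is understood to mean, here is a minimal single-channel sketch in plain C++. It is not the ATen kernel: `RunningStats`, `norm_with_update` and `norm_no_update` are hypothetical names, the affine weight and bias are left out, and the momentum update with an unbiased variance estimate follows the documented `torch.nn.BatchNorm` convention rather than anything shown in this diff.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one channel's running_mean / running_var buffers.
struct RunningStats {
  double mean = 0.0;
  double var = 1.0;
};

// "with_update": normalize with the batch statistics and update the
// running statistics in place. Assumes x is non-empty.
std::vector<double> norm_with_update(const std::vector<double>& x,
                                     RunningStats& stats,
                                     double momentum, double eps) {
  const double n = static_cast<double>(x.size());
  double mean = 0.0;
  for (double v : x) mean += v;
  mean /= n;

  double sq = 0.0;
  for (double v : x) sq += (v - mean) * (v - mean);
  const double biased_var = sq / n;  // used for normalization
  // Assumption: the running estimate uses the unbiased variance.
  const double unbiased_var = x.size() > 1 ? sq / (n - 1.0) : biased_var;

  stats.mean = (1.0 - momentum) * stats.mean + momentum * mean;
  stats.var = (1.0 - momentum) * stats.var + momentum * unbiased_var;

  std::vector<double> out(x.size());
  const double inv_std = 1.0 / std::sqrt(biased_var + eps);
  for (std::size_t i = 0; i < x.size(); ++i) out[i] = (x[i] - mean) * inv_std;
  return out;
}

// "no_update": normalize with the stored running statistics and leave
// them untouched.
std::vector<double> norm_no_update(const std::vector<double>& x,
                                   const RunningStats& stats, double eps) {
  std::vector<double> out(x.size());
  const double inv_std = 1.0 / std::sqrt(stats.var + eps);
  for (std::size_t i = 0; i < x.size(); ++i) out[i] = (x[i] - stats.mean) * inv_std;
  return out;
}
```

Passing the running statistics by mutable reference in one function and by const reference in the other mirrors the `Tensor&` versus `const c10::optional<Tensor>&` split in the signatures above.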

aten/src/ATen/native/Normalization.h

+8
@@ -8,4 +8,12 @@ namespace at::native {
 using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
 DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
 
+enum class BatchNormBackend {
+  Native,
+  Cudnn,
+  Miopen,
+};
+
+TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
+
 } // namespace at::native
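
The selection order encoded by `_select_batch_norm_backend` (try cuDNN, then MIOpen, then fall back to the native kernel) can be sketched without ATen. The version below is a simplified illustration only: `Backend`, `Query` and `select_backend` are hypothetical names, and the dtype, running-stats, minimum-epsilon, cuDNN-version, `MIOPEN_DIM_MAX` and 32-bit `numel` checks from the real function are deliberately omitted. The batch-size limits are the ones that appear in the diff.

```cpp
#include <cassert>
#include <cstdint>

enum class Backend { Native, Cudnn, Miopen };

// Hypothetical summary of the facts the real selector reads from the
// tensors and from the global context.
struct Query {
  bool is_cuda = false;
  bool has_weight_and_bias = true;
  bool training = true;
  int dim = 4;
  std::int64_t batch_size = 1;
  bool channels_last = false;
  bool cudnn_compiled = false;
  bool miopen_compiled = false;
  bool cudnn_enabled = true;  // user-level switch, shared by both vendor paths
};

Backend select_backend(const Query& q) {
  // Batch-size limits taken from the diff: 880801 (training), 65535 (eval).
  const std::int64_t max_batch = q.training ? 880801 : 65535;

  if (q.is_cuda && q.has_weight_and_bias && q.dim >= 3 &&
      q.batch_size <= max_batch && q.cudnn_compiled && q.cudnn_enabled) {
    return Backend::Cudnn;
  }
  if (q.is_cuda && q.has_weight_and_bias && q.miopen_compiled &&
      q.cudnn_enabled && !q.channels_last) {
    return Backend::Miopen;
  }
  return Backend::Native;
}
```

Because forward and backward both call the same selector with the same inputs, no backend index has to be threaded through the graph, which is what `_batch_norm_impl_index` had to do in the old stack.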

aten/src/ATen/native/cuda/Normalization.cu

+78
@@ -1,5 +1,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/detail/CUDAHooksInterface.h>
+#include <ATen/native/Normalization.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/ReduceOps.h>
 #include <ATen/native/Resize.h>
@@ -12,15 +14,21 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/batch_norm_backward_elemt_native.h>
 #include <ATen/ops/batch_norm_backward_reduce_native.h>
 #include <ATen/ops/batch_norm_elemt_native.h>
 #include <ATen/ops/batch_norm_gather_stats_native.h>
 #include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
 #include <ATen/ops/batch_norm_stats_native.h>
 #include <ATen/ops/batch_norm_update_stats_native.h>
+#include <ATen/ops/cudnn_batch_norm.h>
+#include <ATen/ops/cudnn_batch_norm_backward.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/from_blob.h>
+#include <ATen/ops/miopen_batch_norm.h>
+#include <ATen/ops/miopen_batch_norm_backward.h>
 #include <ATen/ops/native_batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm_native.h>
 #include <ATen/ops/scalar_tensor.h>
@@ -473,6 +481,54 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const c10
   return std::make_tuple(output, save_mean, save_invstd);
 }
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
+  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+  Tensor output, save_mean, save_var, reserve;
+
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
+  if (backend == BatchNormBackend::Cudnn) {
+    std::tie(output, save_mean, save_var, reserve) =
+        at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+  } else if (backend == BatchNormBackend::Miopen) {
+    reserve = at::empty({0}, input.options().dtype(kByte));
+    std::tie(output, save_mean, save_var) =
+        at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+  } else {
+    reserve = at::empty({0}, input.options().dtype(kByte));
+    std::tie(output, save_mean, save_var) =
+        batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
+  }
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
+    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
+  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
+  if (backend == BatchNormBackend::Cudnn) {
+    std::tie(out, save_mean, save_var, reserve) =
+        at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+  } else if (backend == BatchNormBackend::Miopen) {
+    std::tie(out, save_mean, save_var) =
+        at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+  } else {
+    std::tie(out, save_mean, save_var) =
+        batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
+  }
+  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
+}
+
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
   return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
 }
@@ -489,6 +545,28 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const
   return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
 }
 
+std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
+    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
+    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
+  const Tensor& dummy_bias = at::empty(1);
+  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
+  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
+  const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});
+
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);
+
+  if (backend == BatchNormBackend::Cudnn) {
+    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
+  } else if (backend == BatchNormBackend::Miopen) {
+    return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
+  } else {
+    return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
+  }
+}
+
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
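
In the CUDA wrappers above, only the cuDNN path yields a real `reserve` tensor; the MIOpen and native paths return three values and the wrapper supplies an empty placeholder so that every backend presents the same four-element result. A minimal sketch of that padding pattern in plain C++ follows. `Buffer` and `with_empty_reserve` are hypothetical names, and a `std::vector<float>` stands in for a tensor.

```cpp
#include <cassert>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor.
using Buffer = std::vector<float>;

// Pads a three-value kernel result (output, save_mean, save_var) with an
// empty reserve buffer so callers always see four values.
std::tuple<Buffer, Buffer, Buffer, Buffer> with_empty_reserve(
    std::tuple<Buffer, Buffer, Buffer> result) {
  return std::tuple_cat(std::move(result), std::make_tuple(Buffer{}));
}
```

The old `_batch_norm_impl_index` path uses the same `std::tuple_cat` idea to append the reserve tensor and the backend index to the MIOpen result.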
