
Commit 773ae81

andrewor14 authored and pytorchmergebot committed
Batch Norm Consolidation (pytorch#116092)
**Summary:** This commit simplifies the existing decomposition hierarchy of batch norm ops by adding a single, backend-agnostic op: `batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the above decomposition hierarchy is CUDA numerics in export flows. We observed significantly worse convergence when training a mobilenetv2-like model with the `_batch_norm_cuda` kernel instead of the `cudnn_batch_norm` kernel. This means users who export their models on CPU first and then move the models to CUDA later may silently see worse accuracy even when cuDNN is installed, because they are using the worse kernel. This issue is summarized in pytorch#111384.

Instead, the new hierarchy proposed by consolidating existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend implementation details and automatically picks the right kernel based on what is installed. This commit also adds the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants; it does not yet change the decomps to produce these ops in the graph. That will be done after the 2-week FC window, and the ops used in the old stack are planned to be removed after the 6-month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: pytorch#111384

Differential Revision: [D54805279](https://our.internmc.facebook.com/intern/diff/D54805279)

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: pytorch#116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
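To make the export problem from the summary concrete, here is a toy C++ sketch. It is not PyTorch code: `Device`, `Runtime`, `run_old_exported_op`, and `run_batch_norm_with_update` are hypothetical names, MIOpen is left out, and kernel names are returned as plain strings purely for illustration.

```cpp
#include <string>

// Hypothetical stand-ins for illustration; none of these are PyTorch APIs.
enum class Device { CPU, CUDA };

struct Runtime {
  Device device;         // where the exported model is being run
  bool cudnn_available;  // whether cuDNN could be used at run time
};

// Old stack (as described in the summary): a graph exported on CPU has
// already been decomposed onto the native path, so moving it to CUDA lands
// on the native CUDA kernel even when cuDNN is available.
inline std::string run_old_exported_op(const Runtime& rt) {
  return rt.device == Device::CUDA ? "_batch_norm_cuda" : "_batch_norm_cpu";
}

// New stack: the graph holds one backend-agnostic op, and the kernel is
// chosen when the op actually runs.
inline std::string run_batch_norm_with_update(const Runtime& rt) {
  if (rt.device == Device::CPU) {
    return "_batch_norm_cpu";
  }
  return rt.cudnn_available ? "cudnn_batch_norm" : "_batch_norm_cuda";
}
```

Under these assumptions, a model exported on CPU and later run with `Runtime{Device::CUDA, true}` resolves to `_batch_norm_cuda` on the old path and to `cudnn_batch_norm` on the new one.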
1 parent a17cd22 commit 773ae81

36 files changed: +779 -72 lines changed

aten/src/ATen/native/Normalization.cpp

+120 -30
@@ -29,6 +29,11 @@
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
+#include <ATen/ops/_batch_norm_with_update.h>
+#include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/_batch_norm_no_update.h>
+#include <ATen/ops/_batch_norm_no_update_native.h>
+#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/alias.h>
 #include <ATen/ops/batch_norm.h>
 #include <ATen/ops/batch_norm_native.h>
@@ -479,10 +484,58 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
 
+BatchNormBackend _select_batch_norm_backend(
+    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
+    const Tensor& running_var, bool training, double eps) {
+
+  auto& ctx = at::globalContext();
+  bool cudnn_enabled = ctx.userEnabledCuDNN();
+
+  if (
+      input.is_cuda()
+      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
+      && (input.scalar_type() != at::kHalf
+        || weight.scalar_type() == at::kFloat)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && (input.dim() >= 3)
+      && ((input.sym_size(0) <= 880801 && training) // spatial, training
+          ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
+      && detail::getCUDAHooks().compiledWithCuDNN()
+      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
+      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
+  ) {
+    return BatchNormBackend::Cudnn;
+  }
+
+  if (
+      input.is_cuda()
+      && input.dim() <= MIOPEN_DIM_MAX
+      && input.scalar_type() != at::kDouble
+      && input.scalar_type() != at::kBFloat16
+      && (weight.scalar_type() != at::kHalf)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && detail::getCUDAHooks().compiledWithMIOpen()
+      && cudnn_enabled
+      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
+      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
+  ) {
+    return BatchNormBackend::Miopen;
+  }
+
+  return BatchNormBackend::Native;
+}
+
+
 // _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection
 // of backends, while enabling it to keep the information about the used backend, so that it can
 // use its corresponding backward implementation.
 // XXX: The indices of backends need to be kept synchronized between this function and its _backward.
+// TODO: remove cudnn_enabled arg
 std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -527,24 +580,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
   }
 
-  const bool use_cudnn = (
-      input.is_cuda()
-      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
-      && (input.scalar_type() != at::kHalf
-        || weight.scalar_type() == at::kFloat)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && ((input.sym_size(0) <= 880801 && training) // spatial, training
-          ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
-      && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
-      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
-      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
-  );
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
 
-  if (use_cudnn) {
+  if (backend == BatchNormBackend::Cudnn) {
     auto input_c = input.contiguous(input.suggest_memory_format());
     auto weight_c = weight.contiguous();
     auto bias_c = bias.contiguous();
@@ -561,19 +599,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
 
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
 
-  bool use_miopen = (input.is_cuda()
-               && input.dim() <= MIOPEN_DIM_MAX
-               && input.scalar_type() != at::kDouble
-               && input.scalar_type() != at::kBFloat16
-               && (weight.scalar_type() != at::kHalf)
-               && weight.defined() && bias.defined()
-               && ((running_mean.defined() && running_var.defined())
-                 || (!running_mean.defined() && !running_var.defined() && training))
-               && detail::getCUDAHooks().compiledWithMIOpen()
-               && cudnn_enabled
-               );
-
-  if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
+  if (backend == BatchNormBackend::Miopen) {
     return std::tuple_cat(
              at::miopen_batch_norm(
                input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -637,6 +663,7 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
   TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
 }
 
+// TODO: remove cudnn_enabled arg
 Tensor batch_norm(
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
     const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -647,6 +674,30 @@ Tensor batch_norm(
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                 training, momentum, eps, cudnn_enabled));
+  // TODO: switch to the new stack after the 2 week FC window
+  // if (training) {
+  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
+  //     auto input_c = input;
+  //     if (backend == BatchNormBackend::Cudnn) {
+  //       input_c = input.contiguous(input.suggest_memory_format());
+  //     } else {
+  //       input_c = input.contiguous();
+  //     }
+  //     auto weight_c = weight.contiguous();
+  //     auto bias_c = bias.contiguous();
+  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
+  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
+  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
+  //                                                   const_cast<Tensor&>(rvar_c), momentum, eps));
+  //   } else {
+  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
+  //                                                   const_cast<Tensor&>(running_var), momentum, eps));
+  //   }
+  // } else {
+  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
+  //                                               momentum, eps));
+  // }
 }
 
 Tensor instance_norm(
@@ -798,6 +849,38 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
   return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
 }
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+  Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+    batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
+    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
+  std::tie(out, save_mean, save_var) =
+    batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
+  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
+}
+
+
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    double momentum, double eps) {
+  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
+  Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+    batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
 
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
     const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -826,6 +909,13 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
   return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
 }
 
+std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
+    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
+    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
+  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
+}
 
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
                                                            bool train, double eps, std::array<bool,3> grad_input_mask) {
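The CPU wrappers in this file follow one pattern: call the existing three-output kernel with the update flag fixed, then append an empty `reserve` so that every backend returns the same four-element tuple. The sketch below imitates that pattern on a toy single-channel batch norm. Everything in it is an assumption made for illustration: `Buffer`, `Reserve`, and the `toy_*` functions are hypothetical, and plain `std::vector` stands in for `at::Tensor`. The use of the unbiased variance for the running-variance update reflects my understanding of the usual batch norm convention, not something shown in this diff.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <tuple>
#include <vector>

using Buffer = std::vector<float>;          // hypothetical stand-in for at::Tensor
using Reserve = std::vector<std::uint8_t>;  // stand-in for the empty kByte reserve

// Toy single-channel kernel returning (output, save_mean, save_invstd).
// With update=true it normalizes with batch statistics and refreshes the
// running statistics in place; otherwise it normalizes with the running ones.
inline std::tuple<Buffer, float, float> toy_batch_norm(
    const Buffer& x, float& running_mean, float& running_var,
    bool update, double momentum, double eps) {
  assert(!x.empty());
  const double n = static_cast<double>(x.size());
  double mean = running_mean;
  double var = running_var;
  if (update) {
    mean = 0.0;
    for (float v : x) mean += v;
    mean /= n;
    var = 0.0;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= n;  // biased variance is used for normalization
    // Assumed convention: running variance tracks the unbiased estimate.
    const double unbiased = n > 1 ? var * n / (n - 1) : var;
    running_mean = static_cast<float>((1 - momentum) * running_mean + momentum * mean);
    running_var = static_cast<float>((1 - momentum) * running_var + momentum * unbiased);
  }
  const double invstd = 1.0 / std::sqrt(var + eps);
  Buffer out;
  out.reserve(x.size());
  for (float v : x) out.push_back(static_cast<float>((v - mean) * invstd));
  return std::make_tuple(out, static_cast<float>(mean), static_cast<float>(invstd));
}

// Same shape as the new wrappers: fix the update flag, append an empty reserve.
inline std::tuple<Buffer, float, float, Reserve> toy_batch_norm_with_update(
    const Buffer& x, float& running_mean, float& running_var,
    double momentum, double eps) {
  return std::tuple_cat(
      toy_batch_norm(x, running_mean, running_var, /*update=*/true, momentum, eps),
      std::make_tuple(Reserve{}));
}

// Running statistics are taken by value here, so the caller's copies cannot change.
inline std::tuple<Buffer, float, float, Reserve> toy_batch_norm_no_update(
    const Buffer& x, float running_mean, float running_var,
    double momentum, double eps) {
  return std::tuple_cat(
      toy_batch_norm(x, running_mean, running_var, /*update=*/false, momentum, eps),
      std::make_tuple(Reserve{}));
}
```

The point of the uniform four-element return is that callers and the backward op can treat all backends alike; only the cuDNN path is expected to put anything meaningful in `reserve`.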

aten/src/ATen/native/Normalization.h

+8
@@ -8,4 +8,12 @@ namespace at::native {
 using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
 DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
 
+enum class BatchNormBackend {
+  Native,
+  Cudnn,
+  Miopen,
+};
+
+TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
+
 } // namespace at::native
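The selection order declared here (cuDNN first, then MIOpen, then the native kernel) can be exercised outside of ATen with a flattened stand-in for the inputs. The following is a minimal sketch under stated assumptions: `BatchNormQuery`, its field names, `select_backend`, and `kMiopenDimMax` are hypothetical, the dtype checks are collapsed into two booleans, and the value 5 for `MIOPEN_DIM_MAX` is assumed rather than taken from this diff. The numeric limits (880801, 65535, 5110, 32-bit element count) are copied from the predicate in `Normalization.cpp`.

```cpp
#include <cstdint>
#include <limits>

enum class BatchNormBackend { Native, Cudnn, Miopen };

// Hypothetical flattened view of what the real function reads from tensors,
// the global context, and the CUDA hooks. Field names are illustrative only.
struct BatchNormQuery {
  bool is_cuda = false;
  bool cudnn_dtypes_ok = false;   // stands in for the BFloat16/Half checks
  bool miopen_dtypes_ok = false;  // stands in for the Double/BFloat16/Half checks
  bool params_defined = false;    // weight and bias are both defined
  bool stats_consistent = false;  // both running stats defined, or both undefined while training
  bool training = false;
  bool channels_last = false;     // ChannelsLast or ChannelsLast3d layout
  bool cudnn_enabled = false;     // user-level switch; it also gates MIOpen in the diff
  bool compiled_with_cudnn = false;
  bool compiled_with_miopen = false;
  long cudnn_version = 0;
  double eps = 1e-5;
  double cudnn_min_eps = 0.0;     // placeholder for batchnormMinEpsilonCuDNN()
  std::int64_t dim = 0;
  std::int64_t batch = 0;         // size of dimension 0
  std::int64_t numel = 0;
};

constexpr std::int64_t kMiopenDimMax = 5;  // assumed value of MIOPEN_DIM_MAX

inline BatchNormBackend select_backend(const BatchNormQuery& q) {
  // Conditions shared by both accelerated paths.
  const bool common =
      q.is_cuda && q.params_defined && q.stats_consistent && q.cudnn_enabled;

  // Batch-size limits differ between training and eval.
  const bool batch_ok = q.training ? q.batch <= 880801 : q.batch <= 65535;

  if (common && q.cudnn_dtypes_ok && q.dim >= 3 && batch_ok &&
      q.compiled_with_cudnn && q.eps >= q.cudnn_min_eps &&
      q.cudnn_version >= 5110L &&
      q.numel < std::numeric_limits<std::int32_t>::max()) {
    return BatchNormBackend::Cudnn;
  }
  if (common && q.miopen_dtypes_ok && q.dim <= kMiopenDimMax &&
      q.compiled_with_miopen && !q.channels_last) {
    return BatchNormBackend::Miopen;
  }
  return BatchNormBackend::Native;
}
```

Because the checks are ordered, an input that fails any cuDNN condition (for example a training batch larger than 880801) falls through to the MIOpen test and then to `Native`, rather than raising an error.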

aten/src/ATen/native/cuda/Normalization.cu

+78
@@ -1,5 +1,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/detail/CUDAHooksInterface.h>
+#include <ATen/native/Normalization.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/ReduceOps.h>
 #include <ATen/native/Resize.h>
@@ -12,15 +14,21 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/batch_norm_backward_elemt_native.h>
 #include <ATen/ops/batch_norm_backward_reduce_native.h>
 #include <ATen/ops/batch_norm_elemt_native.h>
 #include <ATen/ops/batch_norm_gather_stats_native.h>
 #include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
 #include <ATen/ops/batch_norm_stats_native.h>
 #include <ATen/ops/batch_norm_update_stats_native.h>
+#include <ATen/ops/cudnn_batch_norm.h>
+#include <ATen/ops/cudnn_batch_norm_backward.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/from_blob.h>
+#include <ATen/ops/miopen_batch_norm.h>
+#include <ATen/ops/miopen_batch_norm_backward.h>
 #include <ATen/ops/native_batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm_native.h>
 #include <ATen/ops/scalar_tensor.h>
@@ -473,6 +481,54 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const c10
   return std::make_tuple(output, save_mean, save_invstd);
 }
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
+  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+  Tensor output, save_mean, save_var, reserve;
+
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
+  if (backend == BatchNormBackend::Cudnn) {
+    std::tie(output, save_mean, save_var, reserve) =
+        at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+  } else if (backend == BatchNormBackend::Miopen) {
+    reserve = at::empty({0}, input.options().dtype(kByte));
+    std::tie(output, save_mean, save_var) =
+        at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+  } else {
+    reserve = at::empty({0}, input.options().dtype(kByte));
+    std::tie(output, save_mean, save_var) =
+        batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
+  }
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
+    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
+  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
+  if (backend == BatchNormBackend::Cudnn) {
+    std::tie(out, save_mean, save_var, reserve) =
+        at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+  } else if (backend == BatchNormBackend::Miopen) {
+    std::tie(out, save_mean, save_var) =
+        at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+  } else {
+    std::tie(out, save_mean, save_var) =
+        batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
+  }
+  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
+}
+
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
   return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
 }
@@ -489,6 +545,28 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const
   return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
 }
 
+std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
+    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
+    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
+  const Tensor& dummy_bias = at::empty(1);
+  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
+  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
+  const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});
+
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);
+
+  if (backend == BatchNormBackend::Cudnn) {
+    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
+  } else if (backend == BatchNormBackend::Miopen) {
+    return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
+  } else {
+    return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
+  }
+}
+
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
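One detail of `_new_batch_norm_backward_cuda` is easy to miss: the backward op receives no bias, yet the selector only picks an accelerated backend when `bias.defined()` holds, so the diff passes a placeholder `dummy_bias`. The sketch below shows why that placeholder matters, assuming the forward pass was called with a defined bias. All names (`Backend`, `FakeTensor`, `select`, `forward_backend`, `backward_backend`) are hypothetical, and every other condition of the real selector is collapsed into a single `accelerated_usable` flag.

```cpp
// Hypothetical miniature of the selector: only the definedness checks and a
// single "accelerated library is usable" flag are modeled.
enum class Backend { Native, Accelerated };

struct FakeTensor {
  bool defined;
};

inline Backend select(const FakeTensor& weight, const FakeTensor& bias,
                      bool accelerated_usable) {
  return (weight.defined && bias.defined && accelerated_usable)
             ? Backend::Accelerated
             : Backend::Native;
}

// Forward pass: the real bias is available.
inline Backend forward_backend(const FakeTensor& weight, const FakeTensor& bias,
                               bool accelerated_usable) {
  return select(weight, bias, accelerated_usable);
}

// Backward pass: no bias argument exists, so a defined placeholder is
// supplied to keep the bias check from forcing the native path.
inline Backend backward_backend(const FakeTensor& weight, bool accelerated_usable) {
  const FakeTensor dummy_bias{true};
  return select(weight, dummy_bias, accelerated_usable);
}
```

In this model, passing an undefined bias to `select` in the backward pass would return `Backend::Native` even though the forward pass chose `Backend::Accelerated`; the placeholder keeps the two choices aligned when the forward bias was defined.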
