 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
+#include <ATen/ops/_batch_norm_with_update.h>
+#include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/_batch_norm_no_update.h>
+#include <ATen/ops/_batch_norm_no_update_native.h>
+#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/alias.h>
 #include <ATen/ops/batch_norm.h>
 #include <ATen/ops/batch_norm_native.h>
@@ -479,10 +484,58 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
 
+BatchNormBackend _select_batch_norm_backend(
+    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
+    const Tensor& running_var, bool training, double eps) {
+
+  auto& ctx = at::globalContext();
+  bool cudnn_enabled = ctx.userEnabledCuDNN();
+
+  if (
+      input.is_cuda()
+      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
+      && (input.scalar_type() != at::kHalf
+        || weight.scalar_type() == at::kFloat)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && (input.dim() >= 3)
+      && ((input.sym_size(0) <= 880801 && training) // spatial, training
+          || (input.sym_size(0) <= 65535 && !training)) // spatial, eval
+      && detail::getCUDAHooks().compiledWithCuDNN()
+      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
+      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
+  ) {
+    return BatchNormBackend::Cudnn;
+  }
+
+  if (
+      input.is_cuda()
+      && input.dim() <= MIOPEN_DIM_MAX
+      && input.scalar_type() != at::kDouble
+      && input.scalar_type() != at::kBFloat16
+      && (weight.scalar_type() != at::kHalf)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && detail::getCUDAHooks().compiledWithMIOpen()
+      && cudnn_enabled
+      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
+      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
+  ) {
+    return BatchNormBackend::Miopen;
+  }
+
+  return BatchNormBackend::Native;
+}
+
+
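Note: the BatchNormBackend enum returned above is not declared in this hunk; it presumably lives in a shared ATen header so that other call sites can reuse the selection logic. A minimal sketch of the assumed declaration (the header location and enumerator order are assumptions, not taken from this diff):

    // Assumed to be declared in a shared header, e.g. ATen/native/Normalization.h
    enum class BatchNormBackend {
      Native,
      Cudnn,
      Miopen,
    };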
 
 // _batch_norm_impl_index(_backward) are used in the JIT to be able to keep the run-time selection
 // of backends, while enabling it to keep the information about the used backend, so that it can
 // use its corresponding backward implementation.
 // XXX: The indices of backends need to be kept synchronized between this function and its _backward.
+// TODO: remove cudnn_enabled arg
 std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
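As the XXX comment above says, the int64_t in this function's return tuple records which backend ran the forward pass, so that _batch_norm_impl_index_backward can route to the matching backward kernel. A hedged sketch of a call site (the concrete 0/1/2 index values are an assumption inferred from the backend order, not stated in this hunk):

    // Forward: the last tuple element remembers which backend produced the outputs.
    auto [out, save_mean, save_var, reserve, impl_index] =
        at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                   training, momentum, eps, cudnn_enabled);
    // Backward: impl_index is passed back so the matching backward kernel is chosen,
    // e.g. 0 = native, 1 = cudnn, 2 = miopen (assumed mapping).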
@@ -527,24 +580,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
   }
 
-  const bool use_cudnn = (
-      input.is_cuda()
-      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
-      && (input.scalar_type() != at::kHalf
-        || weight.scalar_type() == at::kFloat)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && ((input.sym_size(0) <= 880801 && training) // spatial, training
-          || (input.sym_size(0) <= 65535 && !training)) // spatial, eval
-      && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
-      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
-      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
-  );
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
 
-  if (use_cudnn) {
+  if (backend == BatchNormBackend::Cudnn) {
     auto input_c = input.contiguous(input.suggest_memory_format());
     auto weight_c = weight.contiguous();
     auto bias_c = bias.contiguous();
@@ -561,19 +599,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
 
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
 
-  bool use_miopen = (input.is_cuda()
-      && input.dim() <= MIOPEN_DIM_MAX
-      && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
-  );
-
-  if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
+  if (backend == BatchNormBackend::Miopen) {
     return std::tuple_cat(
         at::miopen_batch_norm(
             input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -637,6 +663,7 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
   TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
 }
 
+// TODO: remove cudnn_enabled arg
 Tensor batch_norm(
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
     const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -647,6 +674,30 @@ Tensor batch_norm(
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                 training, momentum, eps, cudnn_enabled));
+  // TODO: switch to the new stack after the 2 week FC window
+  // if (training) {
+  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
+  //     auto input_c = input;
+  //     if (backend == BatchNormBackend::Cudnn) {
+  //       input_c = input.contiguous(input.suggest_memory_format());
+  //     } else {
+  //       input_c = input.contiguous();
+  //     }
+  //     auto weight_c = weight.contiguous();
+  //     auto bias_c = bias.contiguous();
+  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
+  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
+  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
+  //         const_cast<Tensor&>(rvar_c), momentum, eps));
+  //   } else {
+  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
+  //         const_cast<Tensor&>(running_var), momentum, eps));
+  //   }
+  // } else {
+  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
+  //       momentum, eps));
+  // }
 }
 
 Tensor instance_norm(
@@ -798,6 +849,38 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
   return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
 }
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+  Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+      batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
+    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
+  std::tie(out, save_mean, save_var) =
+      batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
+  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
+}
+
+
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    double momentum, double eps) {
+  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
+  Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+      batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
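Both new CPU entry points above follow the same four-tuple contract: (output, save_mean, save_var, reserve). On CPU the reserve tensor is just an empty byte tensor; presumably only backends with a real workspace, such as cuDNN, populate it. A hedged sketch of calling the new op, with the signature taken from the commented-out code in batch_norm above:

    // running_mean / running_var are updated in place when training statistics
    // are tracked; reserve is empty on CPU.
    auto [out, save_mean, save_var, reserve] =
        at::_batch_norm_with_update(input, weight, bias, running_mean, running_var,
                                    momentum, eps);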
 
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
     const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -826,6 +909,13 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
   return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
 }
 
+std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
+    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
+    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
+  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
+}
 
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
     bool train, double eps, std::array<bool,3> grad_input_mask) {