#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
#include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
+#include <ATen/ops/_batch_norm_with_update.h>
+#include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/_batch_norm_no_update.h>
+#include <ATen/ops/_batch_norm_no_update_native.h>
+#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/alias.h>
#include <ATen/ops/batch_norm.h>
#include <ATen/ops/batch_norm_native.h>
@@ -478,10 +483,58 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

+BatchNormBackend _select_batch_norm_backend(
+    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
+    const Tensor& running_var, bool training, double eps) {
+
+  auto& ctx = at::globalContext();
+  bool cudnn_enabled = ctx.userEnabledCuDNN();
+
+  if (
+      input.is_cuda()
+      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
+      && (input.scalar_type() != at::kHalf
+        || weight.scalar_type() == at::kFloat)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && (input.dim() >= 3)
+      && ((input.sym_size(0) <= 880801 && training) // spatial, training
+          || (input.sym_size(0) <= 65535 && !training)) // spatial, eval
+      && detail::getCUDAHooks().compiledWithCuDNN()
+      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
+      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
+  ) {
+    return BatchNormBackend::Cudnn;
+  }
+
+  if (
+      input.is_cuda()
+      && input.dim() <= MIOPEN_DIM_MAX
+      && input.scalar_type() != at::kDouble
+      && input.scalar_type() != at::kBFloat16
+      && (weight.scalar_type() != at::kHalf)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && detail::getCUDAHooks().compiledWithMIOpen()
+      && cudnn_enabled
+      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
+      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
+  ) {
+    return BatchNormBackend::Miopen;
+  }
+
+  return BatchNormBackend::Native;
+}
+
+
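For orientation, here is a minimal caller-side sketch of dispatching on the selected backend. It is not part of the patch; the tensor names are assumed to be in scope, and `_batch_norm_impl_index` below is the real caller:

```cpp
// Hedged sketch: assumes CUDA float tensors input/weight/bias and defined
// running stats are in scope; training/eps values are illustrative.
BatchNormBackend backend = _select_batch_norm_backend(
    input, weight, bias, running_mean, running_var, /*training=*/true, /*eps=*/1e-5);
switch (backend) {
  case BatchNormBackend::Cudnn:
    // Fast path: inputs must first be made contiguous in the suggested memory format.
    break;
  case BatchNormBackend::Miopen:
    // ROCm fast path; channels-last layouts are rejected by the checks above.
    break;
  case BatchNormBackend::Native:
    // Portable ATen kernel; always a valid fallback.
    break;
}
```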
// _batch_norm_impl_index(_backward) are used in the JIT to be able to keep the run-time selection
// of backends, while enabling it to keep the information about the used backend, so that it can
// use its corresponding backward implementation.
// XXX: The indices of backends need to be kept synchronized between this function and its _backward.
+// TODO: remove cudnn_enabled arg
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
    const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
    bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -526,24 +579,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
    check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
  }

-  const bool use_cudnn = (
-      input.is_cuda()
-      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
-      && (input.scalar_type() != at::kHalf
-        || weight.scalar_type() == at::kFloat)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && ((input.sym_size(0) <= 880801 && training) // spatial, training
-          || (input.sym_size(0) <= 65535 && !training)) // spatial, eval
-      && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
-      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
-      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
-  );
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);

-  if (use_cudnn) {
+  if (backend == BatchNormBackend::Cudnn) {
    auto input_c = input.contiguous(input.suggest_memory_format());
    auto weight_c = weight.contiguous();
    auto bias_c = bias.contiguous();
@@ -560,19 +598,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(

    Tensor reserve = at::empty({0}, input.options().dtype(kByte));

-  bool use_miopen = (input.is_cuda()
-      && input.dim() <= MIOPEN_DIM_MAX
-      && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
-  );
-
-  if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
+  if (backend == BatchNormBackend::Miopen) {
    return std::tuple_cat(
        at::miopen_batch_norm(
            input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -636,6 +662,7 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
  TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
}

+// TODO: remove cudnn_enabled arg
Tensor batch_norm(
    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -646,6 +673,30 @@ Tensor batch_norm(
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
  return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                training, momentum, eps, cudnn_enabled));
+  // TODO: switch to the new stack after the 2 week FC window
+  // if (training) {
+  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
+  //     auto input_c = input;
+  //     if (backend == BatchNormBackend::Cudnn) {
+  //       input_c = input.contiguous(input.suggest_memory_format());
+  //     } else {
+  //       input_c = input.contiguous();
+  //     }
+  //     auto weight_c = weight.contiguous();
+  //     auto bias_c = bias.contiguous();
+  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
+  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
+  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
+  //                                                    const_cast<Tensor&>(rvar_c), momentum, eps));
+  //   } else {
+  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
+  //                                                    const_cast<Tensor&>(running_var), momentum, eps));
+  //   }
+  // } else {
+  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
+  //                                                momentum, eps));
+  // }
}

Tensor instance_norm(
@@ -797,6 +848,38 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
  return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
}

+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+  Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+      batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
+    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
+  std::tie(out, save_mean, save_var) =
+      batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
+  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
+}
+
+
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    double momentum, double eps) {
+  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
+  Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+      batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
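Both new forward entry points return a 4-tuple `(output, save_mean, save_var, reserve)`; on CPU the `reserve` tensor is always empty and exists only so the cuDNN/MIOpen backends can hand their workspace to the backward. A minimal usage sketch (shapes and values are illustrative, not from the patch):

```cpp
// Training-mode batch norm over an NCHW float tensor; the running stats
// are updated in place by the _with_update variant.
auto input = at::randn({2, 3, 4, 4});
auto weight = at::ones({3});
auto bias = at::zeros({3});
auto running_mean = at::zeros({3});
auto running_var = at::ones({3});
auto [out, save_mean, save_var, reserve] = at::_batch_norm_with_update(
    input, weight, bias, running_mean, running_var, /*momentum=*/0.1, /*eps=*/1e-5);
// reserve is an empty byte tensor on CPU; save_mean/save_var feed the backward.
```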

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
    const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -825,6 +908,13 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
  return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
}

+std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
+    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
+    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
+    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
+  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
+}
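On CPU the new backward is a thin wrapper over `batch_norm_backward_cpu`: the `reserve` argument is accepted only for interface parity with the cuDNN/MIOpen paths and is otherwise ignored. Continuing the forward sketch above, and assuming the op is exposed through the dispatcher as `at::batch_norm_backward` with the same argument order as `_new_batch_norm_backward_cpu`:

```cpp
// Hedged sketch: the dispatcher-level name and argument order are assumed
// to mirror _new_batch_norm_backward_cpu above.
auto grad_output = at::ones_like(out);
auto [grad_input, grad_weight, grad_bias] = at::batch_norm_backward(
    grad_output, input, weight, running_mean, running_var,
    save_mean, save_var, /*update=*/true, /*eps=*/1e-5,
    /*grad_input_mask=*/std::array<bool, 3>{true, true, true}, reserve);
```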

std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
    bool train, double eps, std::array<bool,3> grad_input_mask) {