From 01f31a5b2b6a9745372bc2e85b069080db5d2224 Mon Sep 17 00:00:00 2001
From: Harsha HS
Date: Wed, 6 Mar 2024 15:49:15 +0000
Subject: [PATCH] Add CudnnPadForConvolutions and CudnnVectorizeConvolutions
 HLO pass

---
 xla/service/gpu/BUILD                           |  3 +++
 xla/service/gpu/amdgpu_compiler.cc              | 17 +++++++++++++++--
 xla/service/gpu/cudnn_pad_for_convolutions.cc   | 11 ++++++++---
 xla/service/gpu/cudnn_pad_for_convolutions.h    |  5 ++++-
 xla/service/gpu/cudnn_support_utils.cc          | 10 ++++++----
 xla/service/gpu/cudnn_support_utils.h           |  2 +-
 xla/service/gpu/cudnn_vectorize_convolutions.cc |  9 ++++++---
 xla/service/gpu/cudnn_vectorize_convolutions.h  |  7 ++++++-
 xla/stream_executor/device_description.h        |  8 ++++++++
 9 files changed, 57 insertions(+), 15 deletions(-)

diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index 104937f2b8f6f..75486ee8ea409 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -3907,6 +3907,9 @@ cc_library(
         ":conv_algorithm_picker",
         ":cublas_pad_for_gemms",
         ":cublas_padding_requirements",
+        ":cudnn_pad_for_convolutions",
+        ":cudnn_simplify_padding",
+        ":cudnn_vectorize_convolutions",
         ":cusolver_rewriter",
         ":gemm_algorithm_picker",
         ":gemm_rewriter",
diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc
index 4eff13885d8b1..e324adac1cab5 100644
--- a/xla/service/gpu/amdgpu_compiler.cc
+++ b/xla/service/gpu/amdgpu_compiler.cc
@@ -28,6 +28,9 @@ limitations under the License.
 #include "xla/service/gpu/conv_algorithm_picker.h"
 #include "xla/service/gpu/cublas_pad_for_gemms.h"
 #include "xla/service/gpu/cublas_padding_requirements.h"
+#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
+#include "xla/service/gpu/cudnn_simplify_padding.h"
+#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
 #include "xla/service/gpu/cusolver_rewriter.h"
 #include "xla/service/gpu/gemm_algorithm_picker.h"
 #include "xla/service/gpu/gemm_rewriter.h"
@@ -88,6 +91,8 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
     HloModule* hlo_module, se::GpuComputeCapability gpu_version,
     se::dnn::VersionInfo dnn_version,
     se::DeviceMemoryAllocator* device_allocator) {
+  auto rocm_compute_capability =
+      std::get<se::RocmComputeCapability>(gpu_version);
   // Convert convolutions into CustomCalls to MIOpen, then canonicalize them
   // (PadInsertion).
   HloPassPipeline pipeline("conv_canonicalization");
@@ -96,13 +101,14 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
                                                  /*allow_mixed_precision=*/false);
 
   // Convert upsupported bf16 convolutions to f32.
-  ConvBfloat16Support conv_bf16_support(
-      std::get<se::RocmComputeCapability>(gpu_version));
+  ConvBfloat16Support conv_bf16_support(rocm_compute_capability);
   pipeline.AddPass<FloatNormalization>(&conv_bf16_support);
 
   pipeline.AddPass<GpusolverRewriter>();
   pipeline.AddPass<GpuConvRewriter>();
   pipeline.AddPass<GpuConvPaddingLegalization>();
+  pipeline.AddPass<CudnnPadForConvolutions>(rocm_compute_capability);
+  pipeline.AddPass<CudnnVectorizeConvolutions>(rocm_compute_capability);
 
   // The conv padding/vectorization passes which we need to get rid of. They
   // also leave behind unnecessary tuple/get-tuple-element pairs that
@@ -119,6 +125,13 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
   options.set_enable_unconditional_reduce_of_concat_replacement(false);
   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
 
+  // CudnnSimplifyPadding gets rid of some padding introduced by
+  // CudnnPadForConvolutions and used by CudnnVectorizeConvolutions. The
+  // pattern-matches in this pass need to be run after inlining and simplifying
+  // tuples from CudnnVectorizeConvolutions. We also need to run algsimp to
+  // e.g. clean up unnecessary nop `convert`s.
+  pipeline.AddPass<CudnnSimplifyPadding>();
+  pipeline.AddPass();
 
   TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
 
diff --git a/xla/service/gpu/cudnn_pad_for_convolutions.cc b/xla/service/gpu/cudnn_pad_for_convolutions.cc
index e104eea0530e6..74a638bd7d676 100644
--- a/xla/service/gpu/cudnn_pad_for_convolutions.cc
+++ b/xla/service/gpu/cudnn_pad_for_convolutions.cc
@@ -315,7 +315,7 @@ static absl::StatusOr<bool> TryResolvePaddedShapesForTensorCore(
 // Adds padding to cudnn integer convolutions to make input and output feature
 // maps multiples of pad_to (usually 4 or 32).
 absl::StatusOr<bool> TryResolvePaddedShapesForIntegerConvolution(
-    int pad_to, const se::CudaComputeCapability& compute_capability,
+    int pad_to, const se::GpuComputeCapability& compute_capability,
     HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
     Shape* new_result_shape_ptr) {
   TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
@@ -490,13 +490,16 @@ absl::StatusOr<bool> CudnnPadForConvolutions::Run(
     HloModule* module,
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
   bool changed = false;
+  auto* ccc = std::get_if<se::CudaComputeCapability>(&compute_capability_);
   for (HloComputation* comp :
        module->MakeNonfusionComputations(execution_threads)) {
     for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
       // On Turing and later (sm75+), pad to multiples of 32 bytes if possible,
       // because that lets us use the fast int8x32 data type.
       bool local_changed = false;
-      if (compute_capability_.IsAtLeast(7, 5)) {
+      bool isSM75_and_later = false;
+      if (ccc) isSM75_and_later = ccc->IsAtLeast(7, 5);
+      if (isSM75_and_later || se::isROCm(compute_capability_)) {
         TF_ASSIGN_OR_RETURN(
             local_changed,
             ResolveAndPad(conv, absl::bind_front(
@@ -512,7 +515,9 @@ absl::StatusOr<bool> CudnnPadForConvolutions::Run(
       }
       changed |= local_changed;
     }
-    if (compute_capability_.IsAtLeast(se::CudaComputeCapability::VOLTA)) {
+    bool isVOLTA = false;
+    if (ccc) isVOLTA = ccc->IsAtLeast(se::CudaComputeCapability::VOLTA);
+    if (isVOLTA || se::isROCm(compute_capability_)) {
       for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
         TF_ASSIGN_OR_RETURN(
             bool local_changed,
diff --git a/xla/service/gpu/cudnn_pad_for_convolutions.h b/xla/service/gpu/cudnn_pad_for_convolutions.h
index e37f45f3e48ad..571f4afdb1698 100644
--- a/xla/service/gpu/cudnn_pad_for_convolutions.h
+++ b/xla/service/gpu/cudnn_pad_for_convolutions.h
@@ -34,6 +34,9 @@ class CudnnPadForConvolutions : public HloModulePass {
   explicit CudnnPadForConvolutions(se::CudaComputeCapability compute_capability)
       : compute_capability_(compute_capability) {}
 
+  explicit CudnnPadForConvolutions(se::RocmComputeCapability compute_capability)
+      : compute_capability_(compute_capability) {}
+
   absl::string_view name() const override {
     return "cudnn_pad_for_convolutions";
   }
@@ -44,7 +47,7 @@ class CudnnPadForConvolutions : public HloModulePass {
       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
 
  private:
-  const se::CudaComputeCapability compute_capability_;
+  const se::GpuComputeCapability compute_capability_;
 };
 
 }  // namespace gpu
diff --git a/xla/service/gpu/cudnn_support_utils.cc b/xla/service/gpu/cudnn_support_utils.cc
index 7f9cf7074a58a..3294e64c65f38 100644
--- a/xla/service/gpu/cudnn_support_utils.cc
+++ b/xla/service/gpu/cudnn_support_utils.cc
@@ -33,7 +33,7 @@ namespace xla {
 namespace gpu {
 
 absl::StatusOr<bool> CudnnSupportsOptimizedIntegerConvolution(
-    const se::CudaComputeCapability& compute_capability,
+    const se::GpuComputeCapability& compute_capability,
     HloCustomCallInstruction& conv, int vector_size) {
   TF_ASSIGN_OR_RETURN(auto kind,
                      GetCudnnConvKind(&conv));
   const Shape& input_shape = conv.operand(0)->shape();
@@ -50,9 +50,11 @@ absl::StatusOr<bool> CudnnSupportsOptimizedIntegerConvolution(
 
   // Require cc6.1+ for any vectorized integer convolutions
   // Require cc7.5+ for any IMMA convolutions
-  if ((vector_size == 32 && !compute_capability.IsAtLeast(7, 5)) ||
-      !compute_capability.IsAtLeast(6, 1)) {
-    VLOG(3) << "Compute capability " << compute_capability.ToString()
+  bool isCUDA = std::holds_alternative<se::CudaComputeCapability>(compute_capability);
+  auto cuda_compute_capability = std::get<se::CudaComputeCapability>(compute_capability);
+  if ((vector_size == 32 && !cuda_compute_capability.IsAtLeast(7, 5)) ||
+      !cuda_compute_capability.IsAtLeast(6, 1)) {
+    VLOG(3) << "Compute capability " << cuda_compute_capability.ToString()
             << " is not sufficent for int8x" << vector_size
             << " vectorization.";
     return false;
diff --git a/xla/service/gpu/cudnn_support_utils.h b/xla/service/gpu/cudnn_support_utils.h
index f0132f13cd26b..03cd22219b620 100644
--- a/xla/service/gpu/cudnn_support_utils.h
+++ b/xla/service/gpu/cudnn_support_utils.h
@@ -32,7 +32,7 @@ namespace gpu {
 // This function does not guarantee that a convolution will be padded and/or
 // vectorized. It only checks that it is a valid candiate for such optimization.
 absl::StatusOr<bool> CudnnSupportsOptimizedIntegerConvolution(
-    const se::CudaComputeCapability& compute_capability,
+    const se::GpuComputeCapability& compute_capability,
     HloCustomCallInstruction& conv, int vector_size);
 
 // Represents configuration for the reshape-transpose-reshape operations that
diff --git a/xla/service/gpu/cudnn_vectorize_convolutions.cc b/xla/service/gpu/cudnn_vectorize_convolutions.cc
index cecf996c3928c..99e3ae9464cfe 100644
--- a/xla/service/gpu/cudnn_vectorize_convolutions.cc
+++ b/xla/service/gpu/cudnn_vectorize_convolutions.cc
@@ -335,7 +335,7 @@ absl::Status ReorderInt8NchwVect(HloCustomCallInstruction* conv,
 // (The dimensions can appear in any order; which is N/C/etc is determined by
 // the convolutions' dnums.)
 static absl::StatusOr<bool> TryRevectorizeConv(
-    const se::CudaComputeCapability& compute_capability,
+    const se::GpuComputeCapability& compute_capability,
     const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
     int vect_size) {
   const Shape& input_shape = conv->operand(0)->shape();
@@ -496,7 +496,7 @@ static absl::StatusOr<bool> TryRevectorizeConv(
 // This requires that C be a multiple of vect_size. CudnnPadForConvolutions can
 // add padding to make this true.
 static absl::StatusOr<bool> TryVectorizeConv(
-    const se::CudaComputeCapability& compute_capability,
+    const se::GpuComputeCapability& compute_capability,
     const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
     int64_t vect_size) {
   const Shape& input_shape = conv->operand(0)->shape();
@@ -625,7 +625,10 @@ absl::StatusOr<bool> CudnnVectorizeConvolutions::Run(
       // Try to (re)vectorize to int8x32 if this is an sm75+ GPU. If we can't,
       // fall back to int8x4.
       bool local_changed = false;
-      if (compute_capability_.IsAtLeast(7, 5)) {
+      auto* ccc = std::get_if<se::CudaComputeCapability>(&compute_capability_);
+      bool isSM75_and_later = false;
+      if (ccc) isSM75_and_later = ccc->IsAtLeast(7, 5);
+      if (isSM75_and_later || se::isROCm(compute_capability_)) {
         TF_ASSIGN_OR_RETURN(
             local_changed,
             TryRevectorizeConv(compute_capability_, cudnn_version_, conv, 32));
diff --git a/xla/service/gpu/cudnn_vectorize_convolutions.h b/xla/service/gpu/cudnn_vectorize_convolutions.h
index 6dde84e023ad7..8cfa3e448ad69 100644
--- a/xla/service/gpu/cudnn_vectorize_convolutions.h
+++ b/xla/service/gpu/cudnn_vectorize_convolutions.h
@@ -52,6 +52,11 @@ class CudnnVectorizeConvolutions : public HloModulePass {
       : compute_capability_(compute_capability),
         cudnn_version_(cudnn_version) {}
 
+  explicit CudnnVectorizeConvolutions(
+      se::RocmComputeCapability compute_capability)
+      : compute_capability_(compute_capability) {}
+
+
   absl::string_view name() const override {
     return "cudnn_vectorize_convolutions";
   }
@@ -61,7 +66,7 @@ class CudnnVectorizeConvolutions : public HloModulePass {
       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
 
  private:
-  const se::CudaComputeCapability compute_capability_;
+  const se::GpuComputeCapability compute_capability_;
   const se::dnn::VersionInfo cudnn_version_;
 };
diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h
index 6d06755956d9f..d3d2ac48f22fe 100644
--- a/xla/stream_executor/device_description.h
+++ b/xla/stream_executor/device_description.h
@@ -223,6 +223,14 @@ class RocmComputeCapability {
 using GpuComputeCapability =
     std::variant<CudaComputeCapability, RocmComputeCapability>;
 
+static inline bool isCUDA(const GpuComputeCapability& gcc) {
+  return std::holds_alternative<CudaComputeCapability>(gcc);
+}
+
+static inline bool isROCm(const GpuComputeCapability& gcc) {
+  return std::holds_alternative<RocmComputeCapability>(gcc);
+}
+
 // Data that describes the execution target of the StreamExecutor, in terms of
 // important logical parameters. These include dimensionality limits and
 // physical parameters of interest, such as number of cores present on the
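
Note on the pattern used throughout this patch: the passes' `compute_capability_`
member is widened from se::CudaComputeCapability to the se::GpuComputeCapability
variant, the CUDA-only checks are guarded behind std::get_if, and the ROCm case
is detected with the new se::isROCm() helper. Below is a minimal, self-contained
sketch of that dispatch, for illustration only; CudaCC, RocmCC, GpuCC, and
UseInt8x32 are hypothetical stand-ins, not the real StreamExecutor/XLA types.

    // Hypothetical simplified stand-ins for se::CudaComputeCapability /
    // se::RocmComputeCapability, for illustration only.
    #include <iostream>
    #include <variant>

    struct CudaCC {
      int major = 0;
      int minor = 0;
      bool IsAtLeast(int other_major, int other_minor) const {
        return major > other_major ||
               (major == other_major && minor >= other_minor);
      }
    };

    struct RocmCC {};  // gfx arch details omitted.

    using GpuCC = std::variant<CudaCC, RocmCC>;

    // Same shape as the isROCm() helper added to device_description.h above.
    bool IsRocm(const GpuCC& cc) { return std::holds_alternative<RocmCC>(cc); }

    // Same shape as the branches added to CudnnPadForConvolutions::Run() and
    // CudnnVectorizeConvolutions::Run(): take the int8x32 path on sm75+ CUDA
    // devices or on any ROCm device, otherwise fall back to int8x4.
    bool UseInt8x32(const GpuCC& cc) {
      const CudaCC* cuda = std::get_if<CudaCC>(&cc);
      bool is_sm75_or_later = (cuda != nullptr) && cuda->IsAtLeast(7, 5);
      return is_sm75_or_later || IsRocm(cc);
    }

    int main() {
      std::cout << UseInt8x32(CudaCC{7, 0}) << "\n";  // 0: Volta, int8x4 fallback
      std::cout << UseInt8x32(CudaCC{8, 6}) << "\n";  // 1: Ampere, int8x32
      std::cout << UseInt8x32(RocmCC{}) << "\n";      // 1: ROCm path
    }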