
Commit e3aefe2

jeffdaily authored and pytorchmergebot committed
Revert "Initial Flash Attention support on ROCM (pytorch#114309)" (pytorch#115975)
This reverts commit 5bddbed. Pull Request resolved: pytorch#115975 Approved by: https://github.com/atalman, https://github.com/malfet
1 parent 8283491 commit e3aefe2

Showing 14 changed files with 23 additions and 848 deletions.


CMakeLists.txt (+1, -12)

@@ -735,21 +735,10 @@ endif()
 include(cmake/Dependencies.cmake)
 
 # Moved this cmake set option down here because CMAKE_CUDA_COMPILER_VERSION is not avaialble until now
-# TODO: Merge this into cmake_dependent_option as "NOT MSVC AND (USE_CUDA OR USE_ROCM)"
-# once cmake_minimum_required is bumped to 3.22
-# See https://cmake.org/cmake/help/latest/policy/CMP0127.html for the feature required here.
-if(MSVC)
-  set(CONFIG_FA OFF)
-elseif(USE_ROCM OR USE_CUDA)
-  set(CONFIG_FA ON)
-else()
-  set(CONFIG_FA OFF)
-endif()
-
 cmake_dependent_option(
     USE_FLASH_ATTENTION
     "Whether to build the flash_attention kernel for scaled dot product attention" ON
-    "CONFIG_FA" OFF)
+    "USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
 
 # Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
 cmake_dependent_option(

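For context on the restored gate: cmake_dependent_option() only exposes USE_FLASH_ATTENTION when the quoted dependency condition holds, and forces it to OFF otherwise. Below is a minimal, self-contained sketch of that semantics; the demo project name and the stand-in variable values are hypothetical and not part of this change.

cmake_minimum_required(VERSION 3.22)   # 3.22 enables full condition syntax in the dependency string (CMP0127)
project(dep_option_demo NONE)
include(CMakeDependentOption)

set(USE_CUDA ON)                       # stand-in for the real configure flag
set(CMAKE_CUDA_COMPILER_VERSION 12.1)  # faked for the demo; normally set by enable_language(CUDA)

# DEMO_FLASH defaults to ON only while the condition is true; otherwise it is forced OFF,
# which is how USE_FLASH_ATTENTION is gated after this revert.
cmake_dependent_option(DEMO_FLASH "demo of the USE_FLASH_ATTENTION gate" ON
    "USE_CUDA AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)

message(STATUS "DEMO_FLASH=${DEMO_FLASH}")   # prints ON with the values above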
aten/src/ATen/CMakeLists.txt (+3, -34)

@@ -164,10 +164,6 @@ file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
 file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
 file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
 
-# flash_attention sources
-file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
-file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
-
 #Mem_eff attention sources
 file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
 file(GLOB mem_eff_attention_cuda_kernels_cu "native/transformers/cuda/mem_eff_attention/kernels/*.cu")
@@ -179,9 +175,6 @@ if(USE_FLASH_ATTENTION)
   list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
   list(APPEND FLASH_ATTENTION_CUDA_SOURCES ${flash_attention_cuda_cu} ${flash_attention_cuda_kernels_cu})
   list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu})
-
-  list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
-  list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
 endif()
 
 if(USE_MEM_EFF_ATTENTION)
@@ -291,34 +284,10 @@ endif()
 
 if(USE_ROCM)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
-  list(APPEND ATen_HIP_SRCS
-    ${ATen_HIP_SRCS}
-    ${hip_hip}
-    ${native_hip_hip}
-    ${native_nested_hip_hip}
-    ${native_sparse_hip_hip}
-    ${native_quantized_hip_hip}
-    ${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
-  )
+  set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_nested_hip_hip} ${native_sparse_hip_hip} ${native_quantized_hip_hip} ${native_transformers_hip_hip})
   # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
-  list(APPEND all_hip_cpp
-    ${native_nested_hip_cpp}
-    ${native_sparse_hip_cpp}
-    ${native_quantized_hip_cpp}
-    ${native_transformers_hip_cpp}
-    ${native_quantized_cudnn_hip_cpp}
-    ${hip_cpp}
-    ${native_hip_cpp}
-    ${native_hip_linalg_cpp}
-    ${cuda_generated_sources}
-    ${ATen_HIP_SRCS}
-    ${native_miopen_cpp}
-    ${native_cudnn_hip_cpp}
-    ${miopen_cpp}
-    ${all_hip_cpp}
-  )
+  set(all_hip_cpp ${native_nested_hip_cpp} ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${native_transformers_hip_cpp} ${native_quantized_cudnn_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${native_hip_linalg_cpp} ${cuda_generated_sources} ${ATen_HIP_SRCS})
+  set(all_hip_cpp ${native_miopen_cpp} ${native_cudnn_hip_cpp} ${miopen_cpp} ${all_hip_cpp})
 endif()
 
 list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..)

aten/src/ATen/native/transformers/attention.cpp (+1, -10)

@@ -445,13 +445,6 @@ int64_t _fused_sdp_choice_meta(
     bool is_causal,
     c10::optional<double> scale) {
   auto query_key_set = query_.key_set();
-#if defined(USE_ROCM)
-  bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
-  if (has_rocm) {
-    auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale);
-    return choice_int;
-  }
-#else
   bool has_cuda = query_key_set.has(c10::DispatchKey::CUDA);
   if (has_cuda) {
     auto choice_int = _fused_sdp_choice_stub(
@@ -465,7 +458,6 @@ int64_t _fused_sdp_choice_meta(
         scale);
     return choice_int;
   }
-#endif
   return static_cast<int64_t>(sdp::SDPBackend::math);
 }
 namespace {
@@ -633,8 +625,7 @@ Tensor scaled_dot_product_attention(
   validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
   if (query_.device().type() == DeviceType::CUDA
-      || query_.device().type() == DeviceType::CPU
-      || query_.device().type() == DeviceType::HIP){
+      || query_.device().type() == DeviceType::CPU){
     choice_int = _fused_sdp_choice_stub(query_.device().type(),
         query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   }

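For readers tracing the behavioral effect of the attention.cpp hunks above: after the revert, only CUDA and CPU tensors consult the fused-kernel chooser, and every other device type (HIP included) stays on the math backend. The following is a minimal standalone sketch of that selection pattern; DeviceType, SDPBackend, and fused_sdp_choice_stub here are simplified stand-ins, not the real ATen types.

#include <cstdint>
#include <iostream>

// Simplified stand-ins for the real ATen enums and the per-device chooser stub.
enum class DeviceType { CPU, CUDA, HIP, Other };
enum class SDPBackend : int64_t { error = -1, math = 0, flash_attention = 1, efficient_attention = 2 };

// Stand-in for _fused_sdp_choice_stub: pretend CUDA picks flash and CPU picks math.
SDPBackend fused_sdp_choice_stub(DeviceType dev) {
  return dev == DeviceType::CUDA ? SDPBackend::flash_attention : SDPBackend::math;
}

int64_t choose_backend(DeviceType dev) {
  // Default to the unfused math path.
  auto choice_int = static_cast<int64_t>(SDPBackend::math);
  // Only devices with a registered chooser are asked; post-revert that is CUDA and CPU.
  if (dev == DeviceType::CUDA || dev == DeviceType::CPU) {
    choice_int = static_cast<int64_t>(fused_sdp_choice_stub(dev));
  }
  return choice_int;
}

int main() {
  std::cout << "CUDA -> " << choose_backend(DeviceType::CUDA)   // 1 (flash_attention)
            << ", HIP -> " << choose_backend(DeviceType::HIP)   // 0 (math)
            << "\n";
}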
aten/src/ATen/native/transformers/cuda/sdp_utils.cpp (-27)

@@ -14,7 +14,6 @@
 #include <c10/util/Exception.h>
 #include <c10/util/env.h>
 #include <c10/util/irange.h>
-#include <c10/util/CallOnce.h>
 
 #include <c10/core/SymInt.h>
 #include <c10/util/string_view.h>
@@ -182,31 +181,6 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
   auto dprops = at::cuda::getCurrentDeviceProperties();
-#if USE_ROCM
-  constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
-  static const char *over_arch = [] {
-    auto rc = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
-    if (rc) {
-      TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
-                 "Later changes to this environment variable with os.environ "
-                 "(or other methods) will not affect SDPA function's behavior.");
-    }
-    return rc;
-  }();
-  const char* real_arch = dprops->gcnArchName;
-  const char* arch = over_arch ? over_arch : real_arch;
-  if (mi200 != arch) {
-    if (debug) {
-      TORCH_WARN(
-          "Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ",
-          arch,
-          ".",
-          over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "",
-          over_arch ? real_arch : "");
-    }
-    return false;
-  }
-#else
   if (!check_sm_version<sm80, sm90>(dprops)) {
     if (debug) {
       TORCH_WARN(
@@ -218,7 +192,6 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
     }
     return false;
   }
-#endif
   return true;
 }

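With the ROCm branch removed, check_flash_attention_hardware_support reduces to the CUDA compute-capability window [sm80, sm90] checked by check_sm_version<sm80, sm90>. Below is a hedged sketch of that inclusive range check; SMVersion, check_sm_version, and DeviceProps are simplified stand-ins rather than the real cudaDeviceProp plumbing.

#include <iostream>

// Stand-ins mirroring the SMVersion/check_sm_version helpers referenced above.
template <int Major, int Minor>
struct SMVersion {
  static constexpr int major = Major;
  static constexpr int minor = Minor;
};

struct DeviceProps { int major; int minor; };  // simplified stand-in for cudaDeviceProp

// True when the device's compute capability lies in the inclusive [Lower, Upper] window.
template <typename Lower, typename Upper>
bool check_sm_version(const DeviceProps& p) {
  bool at_least_lower = p.major > Lower::major ||
                        (p.major == Lower::major && p.minor >= Lower::minor);
  bool at_most_upper  = p.major < Upper::major ||
                        (p.major == Upper::major && p.minor <= Upper::minor);
  return at_least_lower && at_most_upper;
}

int main() {
  using sm80 = SMVersion<8, 0>;
  using sm90 = SMVersion<9, 0>;
  DeviceProps a100{8, 0}, v100{7, 0};
  std::cout << check_sm_version<sm80, sm90>(a100) << " "    // 1: inside the window
            << check_sm_version<sm80, sm90>(v100) << "\n";  // 0: below sm80, flash unavailable
}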