
Commit 788300c

xw285cornell authored and pytorchmergebot committed
[cudnn] Support v8 API in fbcode (pytorch#96512)
Summary: It turns out we never turn on the cuDNN v8 API, which blocks bf16 conv. Enable the new v8 API.

Test Plan: buck run mode/dev-nosan scripts/xdwang/example:fc_pytorch

Reviewed By: ngimel

Differential Revision: D43784279

Pull Request resolved: pytorch#96512
Approved by: https://github.com/malfet
1 parent fe0afc5 commit 788300c
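
The bf16 convolution the summary refers to can be exercised at the ATen level with a sketch along these lines (illustrative only; the shapes, the use of at::conv2d, and the CUDA check are assumptions, not the fc_pytorch test plan above):

// Illustrative ATen-level bf16 convolution; with the v8 API compiled in, this
// kind of call can dispatch to the cuDNN frontend path touched by this commit.
#include <ATen/ATen.h>

int main() {
  if (!at::hasCUDA()) return 0;  // requires a CUDA build and a visible GPU
  auto opts   = at::device(at::kCUDA).dtype(at::kBFloat16);
  auto input  = at::randn({8, 64, 56, 56}, opts);
  auto weight = at::randn({128, 64, 3, 3}, opts);
  auto output = at::conv2d(input, weight, /*bias=*/{}, /*stride=*/1, /*padding=*/1);
  return output.dim() == 4 ? 0 : 1;
}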

File tree

3 files changed: +9 -8 lines changed


aten/src/ATen/native/cudnn/Conv_v8.cpp

+7-7
@@ -45,7 +45,7 @@ namespace at { namespace native {
 namespace {
 
 // TODO: remove duplicate code in Conv_v7.cpp
-constexpr size_t operator "" _TiB(unsigned long long n) {
+constexpr int64_t operator "" _TiB(unsigned long long n) {
   return size_t(n) << 40;
 }
 
@@ -323,12 +323,12 @@ auto get_generator_sources(const cudnnBackendDescriptorType_t& desc, const Tenso
   }
 }
 
-size_t get_available_workspace() {
+int64_t get_available_workspace() {
   int device;
   C10_CUDA_CHECK(cudaGetDevice(&device));
   size_t max_block_size = 0;
   c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
-  return max_block_size;
+  return static_cast<int64_t>(max_block_size);
 }
 
 static nlohmann::json errata_json_handle;
@@ -347,10 +347,10 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
     return plan_errata_exception(handle, plan.getTag());
   };
   auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
-  size_t max_block_size = get_available_workspace();
-  size_t max_workspace_size = 0u;
+  int64_t max_block_size = get_available_workspace();
+  int64_t max_workspace_size = 0;
   std::for_each(plans.begin(), plans.end(), [&] (cudnn_frontend::ExecutionPlan& plan) {
-    size_t curr_workspace_size = plan.getWorkspaceSize();
+    int64_t curr_workspace_size = plan.getWorkspaceSize();
     if (curr_workspace_size <= max_block_size) {
       if (curr_workspace_size > max_workspace_size) {
         max_workspace_size = plan.getWorkspaceSize();
@@ -373,7 +373,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
   if (remove_invalid) {
     cudnn_frontend::executionPlans_t new_valid_plans;
     for (auto &plan : valid_plans) {
-      if (static_cast<size_t>(plan.getWorkspaceSize()) <= max_workspace_size) {
+      if (plan.getWorkspaceSize() <= max_workspace_size) {
        new_valid_plans.emplace_back(std::move(plan));
      }
    }
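
For context on the size_t to int64_t changes above: cudnn_frontend reports plan workspace sizes as a signed 64-bit value (hence the old static_cast), while the caching allocator reports block sizes as size_t, so keeping the bookkeeping in one signed type avoids mixed-sign comparisons. A standalone sketch of the pattern (the helper names are stand-ins, not PyTorch code):

// Standalone sketch of the signed workspace bookkeeping adopted in Conv_v8.cpp.
// get_plan_workspace_size() stands in for cudnn_frontend's
// ExecutionPlan::getWorkspaceSize() (signed), and query_largest_cached_block()
// stands in for the size_t value reported by CUDACachingAllocator::cacheInfo().
#include <cstddef>
#include <cstdint>
#include <iostream>

int64_t get_plan_workspace_size() { return 64ll << 20; }           // 64 MiB, signed
size_t query_largest_cached_block() { return size_t(256) << 20; }  // 256 MiB, unsigned

int main() {
  // Convert the unsigned allocator value once at the boundary...
  int64_t max_block_size = static_cast<int64_t>(query_largest_cached_block());
  // ...so every later comparison is signed-to-signed, with no per-call casts
  // and no mixed-sign comparison warnings at the call sites.
  int64_t curr_workspace_size = get_plan_workspace_size();
  int64_t max_workspace_size = 0;
  if (curr_workspace_size <= max_block_size && curr_workspace_size > max_workspace_size) {
    max_workspace_size = curr_workspace_size;
  }
  std::cout << "selected workspace: " << max_workspace_size << " bytes\n";
}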

aten/src/ATen/native/cudnn/Macros.h

+1-1
@@ -5,7 +5,7 @@
 // Note: The version below should not actually be 8000. Instead, it should
 // be whatever version of cuDNN that v8 API work with PyTorch correctly.
 // The version is set to 8000 today for convenience of debugging.
-#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
+#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300
 #define HAS_CUDNN_V8() true
 #else
 #define HAS_CUDNN_V8() false
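
A self-contained sketch of how a guard like this behaves; the hard-coded CUDNN_VERSION and the define below are assumptions standing in for what cudnn.h and the build system normally provide:

// Standalone sketch of the Macros.h guard (not the real header). The macro
// only turns on when the build opts in via -DUSE_EXPERIMENTAL_CUDNN_V8_API
// AND the cuDNN headers are at least 8.3 after this change.
#define CUDNN_VERSION 8600               // assumption: normally from cudnn.h
#define USE_EXPERIMENTAL_CUDNN_V8_API    // assumption: normally from defs.bzl / the build

#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300
#define HAS_CUDNN_V8() true
#else
#define HAS_CUDNN_V8() false
#endif

#include <cstdio>

int main() {
  // Call sites branch on the function-like macro instead of repeating the
  // flag-and-version test everywhere.
  if (HAS_CUDNN_V8()) {
    std::puts("cuDNN v8 frontend path compiled in");
  } else {
    std::puts("falling back to the v7 bindings");
  }
}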

defs.bzl

+1
@@ -34,6 +34,7 @@ default_compiler_flags = [
     "-DTH_INDEX_BASE=0",
     "-DMAGMA_V2",
     "-DNO_CUDNN_DESTROY_HANDLE",
+    "-DUSE_EXPERIMENTAL_CUDNN_V8_API", # enable cudnn v8 api
     "-DUSE_FBGEMM",
     "-DUSE_QNNPACK",
     "-DUSE_PYTORCH_QNNPACK",
