Skip to content

Commit 24892b4

Browse files
committed
Use _v7 cudnnGetConvolution.*Algorithm APIs
1 parent 6edc48a commit 24892b4

File tree

2 files changed

+68
-45
lines changed

2 files changed

+68
-45
lines changed

tensorflow/stream_executor/cuda/cuda_dnn.cc

+68
Original file line numberDiff line numberDiff line change
@@ -2526,6 +2526,28 @@ port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
25262526
const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
25272527
const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
25282528
size_t memory_limit_bytes) {
2529+
#if CUDNN_VERSION >= 8000
2530+
const int num_requested_algos = 5;
2531+
int num_returned_algos = 0;
2532+
cudnnConvolutionFwdAlgoPerf_t perf_results[num_requested_algos];
2533+
2534+
RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
2535+
cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2536+
output_nd.handle(), num_requested_algos, &num_returned_algos,
2537+
perf_results));
2538+
2539+
size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2540+
for (int r=0; r<num_returned_algos; r++) {
2541+
if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2542+
perf_results[r].memory <= mem_limit) {
2543+
return perf_results[r].algo;
2544+
}
2545+
}
2546+
return port::Status(
2547+
port::error::INTERNAL,
2548+
"cudnnGetConvolutionForwardAlgorithm_v7 returned "
2549+
"no suitable algorithms. This could be a cudnn bug.");
2550+
#else
25292551
cudnnConvolutionFwdPreference_t preference =
25302552
specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
25312553
: CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
@@ -2534,6 +2556,7 @@ port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
25342556
cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
25352557
output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
25362558
return algo_to_use;
2559+
#endif
25372560
}
25382561

25392562
port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
@@ -2544,6 +2567,28 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
25442567
const CudnnTensorDescriptor& output_nd,
25452568
bool specify_workspace_limit,
25462569
size_t memory_limit_bytes) {
2570+
#if CUDNN_VERSION >= 8000
2571+
const int num_requested_algos = 5;
2572+
int num_returned_algos = 0;
2573+
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_requested_algos];
2574+
2575+
RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(
2576+
cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2577+
input_nd.handle(), num_requested_algos, &num_returned_algos,
2578+
perf_results));
2579+
2580+
size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2581+
for (int r=0; r<num_returned_algos; r++) {
2582+
if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2583+
perf_results[r].memory <= mem_limit) {
2584+
return perf_results[r].algo;
2585+
}
2586+
}
2587+
return port::Status(
2588+
port::error::INTERNAL,
2589+
"cudnnGetConvolutionBackwardDataAlgorithm_v7 returned "
2590+
"no suitable algorithms. This could be a cudnn bug.");
2591+
#else
25472592
cudnnConvolutionBwdDataPreference_t preference =
25482593
specify_workspace_limit
25492594
? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
@@ -2553,6 +2598,7 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
25532598
cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
25542599
input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
25552600
return algo_to_use;
2601+
#endif
25562602
}
25572603

25582604
port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
@@ -2563,6 +2609,27 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
25632609
const CudnnTensorDescriptor& output_nd,
25642610
bool specify_workspace_limit,
25652611
size_t memory_limit_bytes) {
2612+
#if CUDNN_VERSION >= 8000
2613+
const int num_requested_algos = 5;
2614+
int num_returned_algos = 0;
2615+
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_requested_algos];
2616+
2617+
RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
2618+
cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2619+
filter.handle(), num_requested_algos, &num_returned_algos, perf_results));
2620+
2621+
size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2622+
for (int r=0; r<num_returned_algos; r++) {
2623+
if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2624+
perf_results[r].memory <= mem_limit) {
2625+
return perf_results[r].algo;
2626+
}
2627+
}
2628+
return port::Status(
2629+
port::error::INTERNAL,
2630+
"cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned "
2631+
"no suitable algorithms. This could be a cudnn bug.");
2632+
#else
25662633
cudnnConvolutionBwdFilterPreference_t preference =
25672634
specify_workspace_limit
25682635
? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
@@ -2572,6 +2639,7 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
25722639
cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
25732640
filter.handle(), preference, memory_limit_bytes, &algo_to_use));
25742641
return algo_to_use;
2642+
#endif
25752643
}
25762644

25772645
port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(

tensorflow/stream_executor/cuda/cudnn_8_0.inc

-45
Original file line numberDiff line numberDiff line change
@@ -65,21 +65,6 @@ cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle,
6565
return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes);
6666
}
6767

68-
cudnnStatus_t CUDNNWINAPI
69-
cudnnGetConvolutionForwardAlgorithm(cudnnHandle_t handle,
70-
const cudnnTensorDescriptor_t xDesc,
71-
const cudnnFilterDescriptor_t wDesc,
72-
const cudnnConvolutionDescriptor_t convDesc,
73-
const cudnnTensorDescriptor_t yDesc,
74-
cudnnConvolutionFwdPreference_t preference,
75-
size_t memoryLimitInBytes,
76-
cudnnConvolutionFwdAlgo_t *algo) {
77-
using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, cudnnConvolutionFwdAlgo_t *);
78-
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm");
79-
if (!func_ptr) return GetSymbolNotFoundError();
80-
return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, memoryLimitInBytes, algo);
81-
}
82-
8368
cudnnStatus_t CUDNNWINAPI
8469
cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle,
8570
const cudnnTensorDescriptor_t srcDesc,
@@ -211,21 +196,6 @@ cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle,
211196
return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes);
212197
}
213198

214-
cudnnStatus_t CUDNNWINAPI
215-
cudnnGetConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
216-
const cudnnTensorDescriptor_t xDesc,
217-
const cudnnTensorDescriptor_t dyDesc,
218-
const cudnnConvolutionDescriptor_t convDesc,
219-
const cudnnFilterDescriptor_t dwDesc,
220-
cudnnConvolutionBwdFilterPreference_t preference,
221-
size_t memoryLimitInBytes,
222-
cudnnConvolutionBwdFilterAlgo_t *algo) {
223-
using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, size_t, cudnnConvolutionBwdFilterAlgo_t *);
224-
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm");
225-
if (!func_ptr) return GetSymbolNotFoundError();
226-
return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, memoryLimitInBytes, algo);
227-
}
228-
229199
cudnnStatus_t CUDNNWINAPI
230200
cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle,
231201
const cudnnTensorDescriptor_t srcDesc,
@@ -318,21 +288,6 @@ cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle,
318288
return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes);
319289
}
320290

321-
cudnnStatus_t CUDNNWINAPI
322-
cudnnGetConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
323-
const cudnnFilterDescriptor_t wDesc,
324-
const cudnnTensorDescriptor_t dyDesc,
325-
const cudnnConvolutionDescriptor_t convDesc,
326-
const cudnnTensorDescriptor_t dxDesc,
327-
cudnnConvolutionBwdDataPreference_t preference,
328-
size_t memoryLimitInBytes,
329-
cudnnConvolutionBwdDataAlgo_t *algo) {
330-
using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, size_t, cudnnConvolutionBwdDataAlgo_t *);
331-
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm");
332-
if (!func_ptr) return GetSymbolNotFoundError();
333-
return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, memoryLimitInBytes, algo);
334-
}
335-
336291
cudnnStatus_t CUDNNWINAPI
337292
cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle,
338293
const cudnnFilterDescriptor_t filterDesc,

0 commit comments

Comments
 (0)