@@ -56,19 +56,21 @@ void GPUMemoryManager::destroy() {
56
56
mode_ = NO_POOL;
57
57
}
58
58
59
- void GPUMemoryManager::allocate (void ** ptr, size_t size, cudaStream_t stream) {
59
+ bool GPUMemoryManager::try_allocate (void ** ptr, size_t size,
60
+ cudaStream_t stream) {
60
61
CHECK ((ptr) != NULL );
62
+ cudaError_t status = cudaSuccess, last_err = cudaSuccess;
61
63
switch (mode_) {
62
64
case CUB_POOL:
63
- if (cub_allocator->DeviceAllocate (ptr, size, stream) != cudaSuccess) {
65
+ // Cache cleanup and retry logic now live inside DeviceAllocate ()
66
+ status = cub_allocator->DeviceAllocate (ptr, size, stream);
67
+ // If there was a retry and it succeeded we get good status here but
68
+ // we need to clean up last error...
69
+ last_err = cudaGetLastError ();
70
+ // ...and update the dev info if something was wrong
71
+ if (status != cudaSuccess || last_err != cudaSuccess) {
64
72
int cur_device;
65
73
CUDA_CHECK (cudaGetDevice (&cur_device));
66
- // free all cached memory (for all devices), synchrionize
67
- cudaDeviceSynchronize ();
68
- cudaThreadSynchronize ();
69
- cub_allocator->FreeAllCached ();
70
- cudaDeviceSynchronize ();
71
- cudaThreadSynchronize ();
72
74
// Refresh per-device saved values.
73
75
for (int i = 0 ; i < dev_info_.size (); ++i) {
74
76
// only query devices that were initialized
@@ -80,16 +82,13 @@ void GPUMemoryManager::allocate(void** ptr, size_t size, cudaStream_t stream) {
80
82
}
81
83
}
82
84
}
83
- // Retry once
84
- CUDA_CHECK (cub_allocator->DeviceAllocate (ptr, size, stream));
85
85
}
86
- // If retry succeeds we need to clean up last error
87
- cudaGetLastError ();
88
86
break ;
89
87
default :
90
- CUDA_CHECK ( cudaMalloc (ptr, size) );
88
+ status = cudaMalloc (ptr, size);
91
89
break ;
92
90
}
91
+ return status == cudaSuccess;
93
92
}
94
93
95
94
void GPUMemoryManager::deallocate (void * ptr, cudaStream_t stream) {
@@ -172,7 +171,7 @@ void GPUMemoryManager::GetInfo(size_t* free_mem, size_t* total_mem) {
172
171
CUDA_CHECK (cudaGetDevice (&cur_device));
173
172
*total_mem = dev_info_[cur_device].total_ ;
174
173
// Free memory is initial free memory minus outstanding allocations.
175
- // Assuming we only allocate via GPUMemoryManager since its constructon .
174
+ // Assuming we only allocate via GPUMemoryManager since its construction .
176
175
*free_mem = dev_info_[cur_device].free_ -
177
176
cub_allocator->cached_bytes [cur_device].live ;
178
177
break ;
0 commit comments