@@ -34,7 +34,7 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
34
34
public:
35
35
// Constructor: forwards `param` to ConvolutionLayer<Dtype> and sets
// handles_setup_ to false. In this hunk the backward_passed_ctr_(0)
// initializer is the removed line, and use_algo_seeker_(true) /
// use_modest_workspace_(true) are the added line.
// NOTE(review): use_reshape_ and initialized_cached_descs_, declared in the
// second hunk of this diff, do not appear in this initializer list -- confirm
// they are assigned elsewhere before they are first read.
explicit CuDNNConvolutionLayer (const LayerParameter& param)
36
36
: ConvolutionLayer<Dtype>(param), handles_setup_(false ),
37
- backward_passed_ctr_( 0 ) {}
37
+ use_algo_seeker_( true ), use_modest_workspace_( true ) {}
38
38
// Declaration of the virtual LayerSetUp member, which takes the bottom and
// top blob vectors by const reference. Unchanged context in this diff; its
// definition is not shown here.
virtual void LayerSetUp (const vector<Blob<Dtype>*>& bottom,
39
39
const vector<Blob<Dtype>*>& top);
40
40
virtual void Reshape (const vector<Blob<Dtype>*>& bottom,
@@ -65,7 +65,24 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
65
65
size_t *workspace_bwd_data_sizes_;
66
66
size_t *workspace_bwd_filter_sizes_;
67
67
GPUMemory::Workspace workspace;
68
- int backward_passed_ctr_;
68
+
69
+ private:
70
// Two boolean flags added by this change; the constructor in this diff
// initializes both to true.
// NOTE(review): presumably they toggle the algorithm search and a reduced
// workspace budget, respectively -- confirm against the implementation file.
+ bool use_algo_seeker_;
71
+ bool use_modest_workspace_;
72
// Private member added by this change. Takes the bottom and top blob vectors
// and a workspace size in bytes; returns nothing.
// NOTE(review): the name suggests algorithm selection through cuDNN's
// "FindEx"-style search -- confirm against the definition (not in this diff).
+ void FindExConvAlgo (const vector<Blob<Dtype>*>& bottom,
73
+ const vector<Blob<Dtype>*>& top,
74
+ const size_t workspace_bytes);
75
// Private member added by this change, with the same parameter list as
// FindExConvAlgo: bottom and top blob vectors plus a workspace size in bytes.
// NOTE(review): presumably the cuDNN "Get"-style (heuristic) counterpart to
// FindExConvAlgo -- confirm against the definition (not in this diff).
+ void GetConvAlgo (const vector<Blob<Dtype>*>& bottom,
76
+ const vector<Blob<Dtype>*>& top,
77
+ const size_t workspace_bytes);
78
+
79
// Added by this change: two vectors holding cuDNN tensor and convolution
// descriptors, and two bool-returning predicates that take the bottom blobs.
// NOTE(review): presumably the predicates compare the current descriptors
// with these cached copies to detect a shape/config change -- confirm against
// the definitions, and check who owns (creates/destroys) the cached
// descriptors, since nothing in this diff shows their lifetime management.
+ vector<cudnnTensorDescriptor_t> cached_bottom_descs_;
80
+ vector<cudnnConvolutionDescriptor_t> cached_conv_descs_;
81
+ bool IsBottomDescChanged (const vector<Blob<Dtype>*>& bottom);
82
+ bool IsConvDescChanged (const vector<Blob<Dtype>*>& bottom);
83
+
84
// Two more boolean flags added by this change. Unlike use_algo_seeker_ and
// use_modest_workspace_, they are not initialized by the constructor shown in
// this diff.
// NOTE(review): confirm both are assigned (e.g. during setup) before any read;
// otherwise they hold indeterminate values.
+ bool use_reshape_;
85
+ bool initialized_cached_descs_;
69
86
};
70
87
#endif
71
88
0 commit comments