/*
 * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "LayerNormPlugin.h"
#include <iostream>
#include <cstdint>
#include <cuda_fp16.h>

using namespace nvinfer1;

PluginFieldCollection LayerNormPluginCreator::fc_{};
std::vector<PluginField> LayerNormPluginCreator::attr_;

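// One thread block normalizes one row of 768 contiguous elements (e.g. a
// BERT-base hidden vector); its 128 threads each load six values strided
// by 128, so the row length is hard-coded as 6 * 128 = 768.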
template <typename T>
__global__ void layerNormKernel(T *pInput, T *pOutput)
{
    const int tx = threadIdx.x, index = blockIdx.x * 768 + threadIdx.x;

    __shared__ T temp[128];
    // Could these reads index out of bounds?
    T value0 = pInput[index];
    T value1 = pInput[index + 128];
    T value2 = pInput[index + 256];
    T value3 = pInput[index + 384];
    T value4 = pInput[index + 512];
    T value5 = pInput[index + 640];
    temp[tx] = value0 + value1 + value2 + value3 + value4 + value5;
    __syncthreads();

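    // Tree reduction in shared memory: halve the number of active threads
    // each step until temp[0] holds the sum of the entire row.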
    for (int stride = 64; stride >= 1; stride /= 2)
    {
        if (tx < stride)
        {
            temp[tx] += temp[tx + stride];
        }
        __syncthreads();
    }
    T mean = temp[0] / (T)768.0;
    __syncthreads();

    temp[tx] = (value0 - mean) * (value0 - mean) + (value1 - mean) * (value1 - mean) + (value2 - mean) * (value2 - mean) +
               (value3 - mean) * (value3 - mean) + (value4 - mean) * (value4 - mean) + (value5 - mean) * (value5 - mean);
    __syncthreads();

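    // Second tree reduction, this time over the squared deviations, to
    // obtain the (biased) variance of the row.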
    for (int stride = 64; stride >= 1; stride /= 2)
    {
        if (tx < stride)
        {
            temp[tx] += temp[tx + stride];
        }
        __syncthreads();
    }
    T var = temp[0] / (T)768.0;
    T eps = 6e-6;
    pOutput[index]       = (value0 - mean) * (T)rsqrtf(var + eps);
    pOutput[index + 128] = (value1 - mean) * (T)rsqrtf(var + eps);
    pOutput[index + 256] = (value2 - mean) * (T)rsqrtf(var + eps);
    pOutput[index + 384] = (value3 - mean) * (T)rsqrtf(var + eps);
    pOutput[index + 512] = (value4 - mean) * (T)rsqrtf(var + eps);
    pOutput[index + 640] = (value5 - mean) * (T)rsqrtf(var + eps);
}

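// Host-side launch helpers: one overload per supported element type. One
// 128-thread block is launched per row, so nBlock is the number of rows.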
void layerNormCompute(const int nBlock, cudaStream_t stream, const float *input, float *output)
{
    layerNormKernel<float><<<nBlock, 128, 0, stream>>>((float *)input, (float *)output);
}

void layerNormCompute(const int nBlock, cudaStream_t stream, const __half *input, __half *output)
{
    layerNormKernel<__half><<<nBlock, 128, 0, stream>>>((__half *)input, (__half *)output);
}

int32_t LayerNormPlugin::enqueue(const PluginTensorDesc *inputDesc, const PluginTensorDesc *outputDesc, const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept
{
    // Number of rows to normalize: batch dimension times sequence length.
    const int nBlock = inputDesc[0].dims.d[0] * inputDesc[0].dims.d[1];
    // const int dim = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];

    // FP16 path, currently disabled:
    // const __half *input = static_cast<const __half *>(inputs[0]);
    // __half *output = static_cast<__half *>(outputs[0]);

    const float *input = static_cast<const float *>(inputs[0]);
    float *output = static_cast<float *>(outputs[0]);

    layerNormCompute(nBlock, stream, input, output);
    return 0;
}

REGISTER_TENSORRT_PLUGIN(LayerNormPluginCreator);
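
/*
 * A minimal sketch of how this plugin could be instantiated when building a
 * network. The registered name "LayerNorm" and version "1" are assumptions:
 * both are defined in LayerNormPlugin.h, which is not shown here, and may
 * differ. Only standard TensorRT registry APIs are used:
 *
 *   auto *creator = getPluginRegistry()->getPluginCreator("LayerNorm", "1");
 *   PluginFieldCollection fc{};  // this plugin takes no fields
 *   IPluginV2 *plugin = creator->createPlugin("ln", &fc);
 *   ITensor *inputTensors[] = {inputTensor};
 *   auto *lnLayer = network->addPluginV2(inputTensors, 1, *plugin);
 */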