diff --git a/src/kernels/layernorm_kernels.cu b/src/kernels/layernorm_kernels.cu
index 8787d458..0a336657 100644
--- a/src/kernels/layernorm_kernels.cu
+++ b/src/kernels/layernorm_kernels.cu
@@ -285,4 +285,22 @@ void invoke_layernorm_kernel(float* out,
   layer_norm_kernel<<<m, block>>>(out, input, weight, bias, epsilon, n);
 }
 
+template <>
+void invoke_layernorm_kernel(half* out,
+                             const half* input,
+                             const half* weight,
+                             const half* bias,
+                             const float epsilon,
+                             int m,
+                             int n) {
+  int half_n = n / 2;
+  half2* out_ptr = (half2*)out;
+  const half2* input_ptr = (const half2*)input;
+  const half2* weight_ptr = (const half2*)weight;
+  const half2* bias_ptr = (const half2*)bias;
+
+  dim3 block(std::min(half_n, 1024));
+  layer_norm_kernel
+      <<<m, block>>>(out_ptr, input_ptr, weight_ptr, bias_ptr, epsilon, half_n);
+}
 }  // namespace llm::kernel
\ No newline at end of file
diff --git a/src/kernels/layernrom_kernels_test.cu b/src/kernels/layernrom_kernels_test.cu
index 94c0b5b0..df472225 100644
--- a/src/kernels/layernrom_kernels_test.cu
+++ b/src/kernels/layernrom_kernels_test.cu
@@ -1,148 +1,10 @@
 #include <cuda_fp16.h>
-#include <cstdio>
-#include "layernorm_kernels.h"
-#include <torch/torch.h>
 #include <gtest/gtest.h>
+#include <cuda_runtime.h>
 
-template <typename T>
-void printMatrix(T* a, int m, int n) {
-  for (int i = 0; i < m; i++) {
-    for (int j = 0; j < n; j++) {
-      printf("%f ", (float)a[i * n + j]);
-    }
-    puts("");
-  }
-  puts("");
-}
-
-template <>
-void printMatrix(float* a, int m, int n) {
-  for (int i = 0; i < m; i++) {
-    for (int j = 0; j < n; j++) {
-      printf("%f ", a[i * n + j]);
-    }
-    puts("");
-  }
-  puts("");
-}
-
-template <>
-void printMatrix(half* a, int m, int n) {
-  for (int i = 0; i < m; i++) {
-    for (int j = 0; j < n; j++) {
-      printf("%f ", __half2float(a[i * n + j]));
-    }
-    puts("");
-  }
-  puts("");
-}
-
-template <>
-void printMatrix(half2* a, int m, int n) {
-  for (int i = 0; i < m; i++) {
-    for (int j = 0; j < n; j++) {
-      printf(
-          "%f %f ", __half2float(a[i * n + j].x), __half2float(a[i * n + j].y));
-    }
-    puts("");
-  }
-  puts("");
-}
-
-void layernorm_kernel_half2_test() {
-  float epsilon = 1e-6;
-  int m = 2;
-  int n = 2;
-
-  half2* out = (half2*)malloc(m * n * sizeof(half2));
-  half2* input = (half2*)malloc(m * n * sizeof(half2));
-  half2* weight = (half2*)malloc(m * n * sizeof(half2));
-  half2* bias = (half2*)malloc(m * n * sizeof(half2));
-
-  for (int i = 0; i < m; i++) {
-    for (int j = 0; j < n; j++) {
-      input[i * n + j] = half2(__float2half((float)(i * n + j * 2)),
-                               __float2half((float)(i * n + j * 2 + 1)));
-      weight[i * n + j] = half2(__float2half(1.), __float2half(1.));
-      bias[i * n + j] = half2(__float2half(0.), __float2half(0.));
-    }
-  }
-
-  half2* dout;
-  half2* dinput;
-  half2* dweight;
-  half2* dbias;
-  cudaMalloc((void**)&dout, sizeof(half2) * m * n);
-  cudaMalloc((void**)&dinput, sizeof(half2) * m * n);
-  cudaMalloc((void**)&dweight, sizeof(half2) * m * n);
-  cudaMalloc((void**)&dbias, sizeof(half2) * m * n);
-
-  cudaMemcpy(dinput, input, sizeof(half2) * m * n, cudaMemcpyHostToDevice);
-  cudaMemcpy(dweight, weight, sizeof(half2) * m * n, cudaMemcpyHostToDevice);
-  cudaMemcpy(dbias, bias, sizeof(half2) * m * n, cudaMemcpyHostToDevice);
-
-  llm::kernel::invoke_layernorm_kernel(
-      dout, dinput, dweight, dbias, epsilon, m, n);
-
-  cudaMemcpy(out, dout, sizeof(half2) * m * n, cudaMemcpyDeviceToHost);
-
-  printf("---------- test half2 layernorm kernel -----------\n");
-  printf("input:\n");
-  printMatrix(input, m, n);
-  printf("weights:\n");
-  printMatrix(weight, m, n);
-  printf("bias:\n");
-  printMatrix(bias, m, n);
-  printf("outputs:\n");
-  printMatrix(out, m, n);
n); -} - -void layernorm_kernel_float_test() { - float epsilon = 1e-6; - int m = 2; - int n = 4; - - float* out = (float*)malloc(m * n * sizeof(float)); - float* input = (float*)malloc(m * n * sizeof(float)); - float* weight = (float*)malloc(m * n * sizeof(float)); - float* bias = (float*)malloc(m * n * sizeof(float)); - - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - input[i * n + j] = (float)(i * n + j); - weight[i * n + j] = 1.; - bias[i * n + j] = 0.; - } - } - - float* dout; - float* dinput; - float* dweight; - float* dbias; - cudaMalloc((void**)&dout, sizeof(float) * m * n); - cudaMalloc((void**)&dinput, sizeof(float) * m * n); - cudaMalloc((void**)&dweight, sizeof(float) * m * n); - cudaMalloc((void**)&dbias, sizeof(float) * m * n); - - cudaMemcpy(dinput, input, sizeof(float) * m * n, cudaMemcpyHostToDevice); - cudaMemcpy(dweight, weight, sizeof(float) * m * n, cudaMemcpyHostToDevice); - cudaMemcpy(dbias, bias, sizeof(float) * m * n, cudaMemcpyHostToDevice); - - llm::kernel::invoke_layernorm_kernel( - dout, dinput, dweight, dbias, epsilon, m, n); +#include - cudaMemcpy(out, dout, sizeof(float) * m * n, cudaMemcpyDeviceToHost); - - printf("---------- test float layernorm kernel -----------\n"); - printf("input:\n"); - printMatrix(input, m, n); - printf("weights:\n"); - printMatrix(weight, m, n); - printf("bias:\n"); - printMatrix(bias, m, n); - printf("outputs:\n"); - printMatrix(out, m, n); -} +#include "layernorm_kernels.h" TEST(NormalizationKernelTest, LayernormFloatTest) { float epsilon = 1e-6; @@ -152,7 +14,10 @@ TEST(NormalizationKernelTest, LayernormFloatTest) { auto input = torch::randn({m, n}); auto weight = torch::randn({n}); auto bias = torch::randn({n}); - auto desired_out = torch::nn::functional::layer_norm(input, torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(bias)); + auto desired_out = torch::nn::functional::layer_norm( + input, + torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias( + bias)); float* hout = (float*)malloc(m * n * sizeof(float)); float* hinput = input.data_ptr(); @@ -184,4 +49,65 @@ TEST(NormalizationKernelTest, LayernormFloatTest) { cudaFree(dinput); cudaFree(dweight); cudaFree(dbias); -} \ No newline at end of file +} + +TEST(NormalizationKernelTest, LayernormHalfTest) { + float epsilon = 1e-6; + int m = 4; + int n = 512; + auto input = torch::randn({m, n}); + auto weight = torch::randn({n}); + auto bias = torch::randn({n}); + auto desired_out = torch::nn::functional::layer_norm( + input, + torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias( + bias)); + + half* hout = (half*)malloc(m * n * sizeof(half)); + half* hinput = (half*)malloc(m * n * sizeof(half)); + half* hweight = (half*)malloc(n * sizeof(half)); + half* hbias = (half*)malloc(n * sizeof(half)); + + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + hinput[i * n + j] = __float2half(input[i][j].item()); + } + } + for (int i = 0; i < weight.numel(); i++) + hweight[i] = __float2half(weight[i].item()); + for (int i = 0; i < bias.numel(); i++) + hbias[i] = __float2half(bias[i].item()); + + half* dout; + half* dinput; + half* dweight; + half* dbias; + cudaMalloc((void**)&dout, sizeof(half) * m * n); + cudaMalloc((void**)&dinput, sizeof(half) * m * n); + cudaMalloc((void**)&dweight, sizeof(half) * n); + cudaMalloc((void**)&dbias, sizeof(half) * n); + + cudaMemcpy(dinput, hinput, sizeof(half) * m * n, cudaMemcpyHostToDevice); + cudaMemcpy(dweight, hweight, sizeof(half) * n, cudaMemcpyHostToDevice); + 
+  cudaMemcpy(dbias, hbias, sizeof(half) * n, cudaMemcpyHostToDevice);
+
+  llm::kernel::invoke_layernorm_kernel(
+      dout, dinput, dweight, dbias, epsilon, m, n);
+
+  cudaMemcpy(hout, dout, sizeof(half) * m * n, cudaMemcpyDeviceToHost);
+
+  float* float_hout = (float*)malloc(m * n * sizeof(float));
+  for (int i = 0; i < m * n; i++) float_hout[i] = __half2float(hout[i]);
+
+  auto out = torch::from_blob(float_hout, {m, n});
+  EXPECT_TRUE(torch::allclose(out, desired_out, 0.05, 1e-3));
+  free(hout);
+  free(hinput);
+  free(hweight);
+  free(hbias);
+  free(float_hout);
+  cudaFree(dout);
+  cudaFree(dinput);
+  cudaFree(dweight);
+  cudaFree(dbias);
+}
diff --git a/src/kernels/reduce_kernel_utils.cuh b/src/kernels/reduce_kernel_utils.cuh
index edb9d953..16e077ad 100644
--- a/src/kernels/reduce_kernel_utils.cuh
+++ b/src/kernels/reduce_kernel_utils.cuh
@@ -198,9 +198,8 @@ struct TopK {
 
 // operator for cub::BlockReduce to get topk across a thread block
 template <typename T, int K>
-__device__ __forceinline__ TopK<T, K> reduce_topk_op(
-    const TopK<T, K>& a,
-    const TopK<T, K>& b) {
+__device__ __forceinline__ TopK<T, K> reduce_topk_op(const TopK<T, K>& a,
+                                                     const TopK<T, K>& b) {
   TopK<T, K> res = a;
   for (int i = 0; i < K; ++i) {
     res.insert(b.u[i], b.p[i]);