added layernorm kernel half2 unit test using gtest library
Xianzhe Dong committed Apr 27, 2024
1 parent 6ff1738 commit bc9f7e2
Showing 3 changed files with 89 additions and 146 deletions.
18 changes: 18 additions & 0 deletions src/kernels/layernorm_kernels.cu
@@ -285,4 +285,22 @@ void invoke_layernorm_kernel<float>(float* out,
layer_norm_kernel<float><<<m, n>>>(out, input, weight, bias, epsilon, n);
}

template <>
void invoke_layernorm_kernel<half>(half* out,
const half* input,
const half* weight,
const half* bias,
const float epsilon,
int m,
int n) {
// Reinterpret the half buffers as half2 so each thread processes a pair of
// values; assumes the row width n is even.
int half_n = n / 2;
half2* out_ptr = (half2*)out;
const half2* input_ptr = (const half2*)input;
const half2* weight_ptr = (const half2*)weight;
const half2* bias_ptr = (const half2*)bias;

dim3 block(std::min(half_n, 1024));
layer_norm_kernel<half2>
<<<m, block>>>(out_ptr, input_ptr, weight_ptr, bias_ptr, epsilon, half_n);
}
} // namespace llm::kernel
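
Note: this hunk adds only the host-side launcher; the templated device kernel it dispatches to, layer_norm_kernel<half2>, is defined elsewhere in layernorm_kernels.cu and is not shown here. For orientation, the following is a minimal hypothetical sketch of a half2 row-wise layernorm matching this launch configuration (one block per row, min(n/2, 1024) threads per block). The kernel name, the naive shared-memory tree reduction, and the power-of-two block-size assumption are illustrative only; the repository's kernel presumably uses the block-reduce helpers from reduce_kernel_utils.cuh instead.

#include <cuda_fp16.h>

// Hypothetical sketch only -- not the repository's layer_norm_kernel<half2>.
// One block normalizes one row of n packed half2 values (2*n half values).
// Assumes blockDim.x is a power of two and <= 1024.
__global__ void layer_norm_half2_sketch(half2* out,
                                        const half2* input,
                                        const half2* weight,
                                        const half2* bias,
                                        float epsilon,
                                        int n) {
  __shared__ float s_sum[1024];
  __shared__ float s_sqsum[1024];

  const half2* in_row = input + blockIdx.x * n;
  half2* out_row = out + blockIdx.x * n;

  // Per-thread partial sums over both half lanes, accumulated in float.
  float sum = 0.f, sqsum = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    float2 v = __half22float2(in_row[i]);
    sum += v.x + v.y;
    sqsum += v.x * v.x + v.y * v.y;
  }
  s_sum[threadIdx.x] = sum;
  s_sqsum[threadIdx.x] = sqsum;
  __syncthreads();

  // Naive tree reduction in shared memory.
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      s_sum[threadIdx.x] += s_sum[threadIdx.x + stride];
      s_sqsum[threadIdx.x] += s_sqsum[threadIdx.x + stride];
    }
    __syncthreads();
  }

  const float count = 2.0f * n;  // number of half values per row
  const float mean = s_sum[0] / count;
  const float var = s_sqsum[0] / count - mean * mean;
  const float rstd = rsqrtf(var + epsilon);

  // Normalize, then apply the per-element scale and shift.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    float2 v = __half22float2(in_row[i]);
    float2 w = __half22float2(weight[i]);
    float2 b = __half22float2(bias[i]);
    float2 r;
    r.x = (v.x - mean) * rstd * w.x + b.x;
    r.y = (v.y - mean) * rstd * w.y + b.y;
    out_row[i] = __float22half2_rn(r);
  }
}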
212 changes: 69 additions & 143 deletions src/kernels/layernrom_kernels_test.cu
@@ -1,148 +1,10 @@
#include <cuda_fp16.h>
#include <cstdio>
#include "layernorm_kernels.h"
#include <torch/nn/functional.h>
#include <gtest/gtest.h>
#include <torch/nn/functional.h>

template <typename T>
void printMatrix(T* a, int m, int n) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
printf("%f ", (float)a[i * n + j]);
}
puts("");
}
puts("");
}

template <>
void printMatrix<float>(float* a, int m, int n) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
printf("%f ", a[i * n + j]);
}
puts("");
}
puts("");
}

template <>
void printMatrix<half>(half* a, int m, int n) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
printf("%f ", __half2float(a[i * n + j]));
}
puts("");
}
puts("");
}

template <>
void printMatrix<half2>(half2* a, int m, int n) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
printf(
"%f %f ", __half2float(a[i * n + j].x), __half2float(a[i * n + j].y));
}
puts("");
}
puts("");
}

void layernorm_kernel_half2_test() {
float epsilon = 1e-6;
int m = 2;
int n = 2;

half2* out = (half2*)malloc(m * n * sizeof(half2));
half2* input = (half2*)malloc(m * n * sizeof(half2));
half2* weight = (half2*)malloc(m * n * sizeof(half2));
half2* bias = (half2*)malloc(m * n * sizeof(half2));

for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
input[i * n + j] = half2(__float2half((float)(i * n + j * 2)),
__float2half((float)(i * n + j * 2 + 1)));
weight[i * n + j] = half2(__float2half(1.), __float2half(1.));
bias[i * n + j] = half2(__float2half(0.), __float2half(0.));
}
}

half2* dout;
half2* dinput;
half2* dweight;
half2* dbias;
cudaMalloc((void**)&dout, sizeof(half2) * m * n);
cudaMalloc((void**)&dinput, sizeof(half2) * m * n);
cudaMalloc((void**)&dweight, sizeof(half2) * m * n);
cudaMalloc((void**)&dbias, sizeof(half2) * m * n);

cudaMemcpy(dinput, input, sizeof(half2) * m * n, cudaMemcpyHostToDevice);
cudaMemcpy(dweight, weight, sizeof(half2) * m * n, cudaMemcpyHostToDevice);
cudaMemcpy(dbias, bias, sizeof(half2) * m * n, cudaMemcpyHostToDevice);

llm::kernel::invoke_layernorm_kernel<half2>(
dout, dinput, dweight, dbias, epsilon, m, n);

cudaMemcpy(out, dout, sizeof(half2) * m * n, cudaMemcpyDeviceToHost);

printf("---------- test half2 layernorm kernel -----------\n");
printf("input:\n");
printMatrix<half2>(input, m, n);
printf("weights:\n");
printMatrix<half2>(weight, m, n);
printf("bias:\n");
printMatrix<half2>(bias, m, n);
printf("outputs:\n");
printMatrix<half2>(out, m, n);
}

void layernorm_kernel_float_test() {
float epsilon = 1e-6;
int m = 2;
int n = 4;

float* out = (float*)malloc(m * n * sizeof(float));
float* input = (float*)malloc(m * n * sizeof(float));
float* weight = (float*)malloc(m * n * sizeof(float));
float* bias = (float*)malloc(m * n * sizeof(float));

for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
input[i * n + j] = (float)(i * n + j);
weight[i * n + j] = 1.;
bias[i * n + j] = 0.;
}
}

float* dout;
float* dinput;
float* dweight;
float* dbias;
cudaMalloc((void**)&dout, sizeof(float) * m * n);
cudaMalloc((void**)&dinput, sizeof(float) * m * n);
cudaMalloc((void**)&dweight, sizeof(float) * m * n);
cudaMalloc((void**)&dbias, sizeof(float) * m * n);

cudaMemcpy(dinput, input, sizeof(float) * m * n, cudaMemcpyHostToDevice);
cudaMemcpy(dweight, weight, sizeof(float) * m * n, cudaMemcpyHostToDevice);
cudaMemcpy(dbias, bias, sizeof(float) * m * n, cudaMemcpyHostToDevice);

llm::kernel::invoke_layernorm_kernel<float>(
dout, dinput, dweight, dbias, epsilon, m, n);
#include <cstdio>

cudaMemcpy(out, dout, sizeof(float) * m * n, cudaMemcpyDeviceToHost);

printf("---------- test float layernorm kernel -----------\n");
printf("input:\n");
printMatrix<float>(input, m, n);
printf("weights:\n");
printMatrix<float>(weight, m, n);
printf("bias:\n");
printMatrix<float>(bias, m, n);
printf("outputs:\n");
printMatrix<float>(out, m, n);
}
#include "layernorm_kernels.h"

TEST(NormalizationKernelTest, LayernormFloatTest) {
float epsilon = 1e-6;
@@ -152,7 +14,10 @@ TEST(NormalizationKernelTest, LayernormFloatTest) {
auto input = torch::randn({m, n});
auto weight = torch::randn({n});
auto bias = torch::randn({n});
auto desired_out = torch::nn::functional::layer_norm(input, torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(bias));
auto desired_out = torch::nn::functional::layer_norm(
input,
torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(
bias));

float* hout = (float*)malloc(m * n * sizeof(float));
float* hinput = input.data_ptr<float>();
@@ -184,4 +49,65 @@ TEST(NormalizationKernelTest, LayernormFloatTest) {
cudaFree(dinput);
cudaFree(dweight);
cudaFree(dbias);
}
}

TEST(NormalizationKernelTest, LayernormHalfTest) {
float epsilon = 1e-6;
int m = 4;
int n = 512;
auto input = torch::randn({m, n});
auto weight = torch::randn({n});
auto bias = torch::randn({n});
auto desired_out = torch::nn::functional::layer_norm(
input,
torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(
bias));

half* hout = (half*)malloc(m * n * sizeof(half));
half* hinput = (half*)malloc(m * n * sizeof(half));
half* hweight = (half*)malloc(n * sizeof(half));
half* hbias = (half*)malloc(n * sizeof(half));

for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
hinput[i * n + j] = __float2half(input[i][j].item<float>());
}
}
for (int i = 0; i < weight.numel(); i++)
hweight[i] = __float2half(weight[i].item<float>());
for (int i = 0; i < bias.numel(); i++)
hbias[i] = __float2half(bias[i].item<float>());

half* dout;
half* dinput;
half* dweight;
half* dbias;
cudaMalloc((void**)&dout, sizeof(half) * m * n);
cudaMalloc((void**)&dinput, sizeof(half) * m * n);
cudaMalloc((void**)&dweight, sizeof(half) * n);
cudaMalloc((void**)&dbias, sizeof(half) * n);

cudaMemcpy(dinput, hinput, sizeof(half) * m * n, cudaMemcpyHostToDevice);
cudaMemcpy(dweight, hweight, sizeof(half) * n, cudaMemcpyHostToDevice);
cudaMemcpy(dbias, hbias, sizeof(half) * n, cudaMemcpyHostToDevice);

llm::kernel::invoke_layernorm_kernel<half>(
dout, dinput, dweight, dbias, epsilon, m, n);

cudaMemcpy(hout, dout, sizeof(half) * m * n, cudaMemcpyDeviceToHost);

// Convert the half outputs back to float so they can be wrapped in a tensor
// and compared against the float reference with torch::allclose.
float* float_hout = (float*)malloc(m * n * sizeof(float));
for (int i = 0; i < m * n; i++) float_hout[i] = __half2float(hout[i]);

auto out = torch::from_blob(float_hout, {m, n});
EXPECT_TRUE(torch::allclose(out, desired_out, 0.05, 1e-3));
free(hout);
free(hinput);
free(hweight);
free(hbias);
free(float_hout);
cudaFree(dout);
cudaFree(dinput);
cudaFree(dweight);
cudaFree(dbias);
}
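
Since these TEST cases use the gtest library, the test binary also needs a gtest entry point. The repository presumably links against gtest_main or provides its own main; a minimal, hypothetical version is shown below only to make the setup concrete.

#include <gtest/gtest.h>

// Hypothetical entry point for the test binary (normally supplied by gtest_main).
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}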
5 changes: 2 additions & 3 deletions src/kernels/reduce_kernel_utils.cuh
@@ -198,9 +198,8 @@ struct TopK {

// operator for cub::BlockReduce to get topk across a thread block
template <typename T, int K>
__device__ __forceinline__ TopK<T, K> reduce_topk_op(
const TopK<T, K>& a,
const TopK<T, K>& b) {
__device__ __forceinline__ TopK<T, K> reduce_topk_op(const TopK<T, K>& a,
const TopK<T, K>& b) {
TopK<T, K> res = a;
for (int i = 0; i < K; ++i) {
res.insert(b.u[i], b.p[i]);
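
The visible part of this hunk only reformats the reduce_topk_op signature, but the comment above it notes that the operator is meant to be plugged into cub::BlockReduce. As context, here is a hypothetical usage sketch; the kernel name, the sentinel initialization, and the float-only instantiation are assumptions rather than code from this repository, and it presumes TopK<T, K> from reduce_kernel_utils.cuh exposes the u/p arrays and insert() seen above (namespace qualifiers are omitted).

#include <cfloat>
#include <cub/cub.cuh>
#include "reduce_kernel_utils.cuh"

// Hypothetical sketch: block-wide top-K over n float scores, single block.
template <int K, int BLOCK_SIZE>
__global__ void block_topk_sketch(const float* scores, int n, TopK<float, K>* out) {
  using BlockReduce = cub::BlockReduce<TopK<float, K>, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Each thread builds a partial top-K over its strided slice of the input.
  TopK<float, K> partial;
  for (int i = 0; i < K; ++i) {
    partial.p[i] = -1;        // sentinel index
    partial.u[i] = -FLT_MAX;  // sentinel value
  }
  for (int i = threadIdx.x; i < n; i += BLOCK_SIZE) {
    partial.insert(scores[i], i);
  }

  // Merge the per-thread partials across the block with reduce_topk_op.
  TopK<float, K> total =
      BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<float, K>);
  if (threadIdx.x == 0) {
    *out = total;
  }
}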
