use gtest library to rewrite layernorm kernel unit test
Xianzhe Dong committed Apr 27, 2024
1 parent 14a99f7 commit 6ff1738
Showing 2 changed files with 73 additions and 23 deletions.
26 changes: 9 additions & 17 deletions src/kernels/CMakeLists.txt
@@ -73,23 +73,15 @@ cc_library(
  torch
)

-# cc_test(
-#   NAME
-#   layernorm_kernels_test
-#   SRCS
-#   layernrom_kernels_test.cu
-#   layernorm_kernels.cu
-#   DEPS
-#   DEFINES
-# )
-cc_binary(
-  NAME
-  layernorm_kernels_test
-  SRCS
-  layernrom_kernels_test.cu
-  layernorm_kernels.cu
-  DEPS
-  torch
+cc_test(
+  NAME
+  layernorm_kernels_test
+  SRCS
+  layernrom_kernels_test.cu
+  layernorm_kernels.cu
+  DEPS
+  torch
+  GTest::gtest_main
)

add_subdirectory(flash_attn)
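Linking against GTest::gtest_main is what lets the commit drop the hand-written entry point from the test source: that target supplies a default main that initializes Google Test and runs every registered TEST. The snippet below is only an illustration of what that default main amounts to, not code from this repository (cc_test itself is a project-local CMake helper, presumably wrapping add_executable and the gtest linkage).

```cpp
#include <gtest/gtest.h>

// Roughly what linking GTest::gtest_main provides, so the test
// translation unit no longer needs to define its own main().
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
```

If cc_test also registers the target with CTest (an assumption about this project's helper), the test can be run with something like `ctest -R layernorm_kernels_test`, or by executing the built binary directly.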
70 changes: 64 additions & 6 deletions src/kernels/layernrom_kernels_test.cu
@@ -1,8 +1,8 @@
#include <cuda_fp16.h>

#include <cstdio>

#include "layernorm_kernels.h"
#include <torch/nn/functional.h>
#include <gtest/gtest.h>

template <typename T>
void printMatrix(T* a, int m, int n) {
@@ -15,6 +15,28 @@ void printMatrix(T* a, int m, int n) {
puts("");
}

template <>
void printMatrix<float>(float* a, int m, int n) {
  for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
      printf("%f ", a[i * n + j]);
    }
    puts("");
  }
  puts("");
}

template <>
void printMatrix<half>(half* a, int m, int n) {
  for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
      printf("%f ", __half2float(a[i * n + j]));
    }
    puts("");
  }
  puts("");
}

template <>
void printMatrix<half2>(half2* a, int m, int n) {
  for (int i = 0; i < m; i++) {
@@ -122,8 +144,44 @@ void layernorm_kernel_float_test() {
printMatrix<float>(out, m, n);
}

int main() {
  layernorm_kernel_float_test();
  layernorm_kernel_half2_test();
  return 0;
TEST(NormalizationKernelTest, LayernormFloatTest) {
  float epsilon = 1e-6;
  int m = 32;
  int n = 512;

  // Reference result computed on the host with torch's layer_norm.
  auto input = torch::randn({m, n});
  auto weight = torch::randn({n});
  auto bias = torch::randn({n});
  auto desired_out = torch::nn::functional::layer_norm(
      input,
      torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(bias));

  float* hout = (float*)malloc(m * n * sizeof(float));
  float* hinput = input.data_ptr<float>();
  float* hweight = weight.data_ptr<float>();
  float* hbias = bias.data_ptr<float>();

  // Device buffers for the CUDA kernel under test.
  float* dout;
  float* dinput;
  float* dweight;
  float* dbias;
  cudaMalloc((void**)&dout, sizeof(float) * m * n);
  cudaMalloc((void**)&dinput, sizeof(float) * m * n);
  cudaMalloc((void**)&dweight, sizeof(float) * n);
  cudaMalloc((void**)&dbias, sizeof(float) * n);

  cudaMemcpy(dinput, hinput, sizeof(float) * m * n, cudaMemcpyHostToDevice);
  cudaMemcpy(dweight, hweight, sizeof(float) * n, cudaMemcpyHostToDevice);
  cudaMemcpy(dbias, hbias, sizeof(float) * n, cudaMemcpyHostToDevice);

  llm::kernel::invoke_layernorm_kernel<float>(
      dout, dinput, dweight, dbias, epsilon, m, n);

  cudaMemcpy(hout, dout, sizeof(float) * m * n, cudaMemcpyDeviceToHost);

  // Compare the kernel output against the torch reference (rtol=1e-3, atol=1e-5).
  auto out = torch::from_blob(hout, {m, n});
  EXPECT_TRUE(torch::allclose(out, desired_out, 1e-3, 1e-5));

  free(hout);
  cudaFree(dout);
  cudaFree(dinput);
  cudaFree(dweight);
  cudaFree(dbias);
}
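For reference, the computation that both invoke_layernorm_kernel and torch's layer_norm are expected to perform is standard row-wise layer normalization. The sketch below is not part of the commit; it assumes the kernel normalizes each row of the m x n input with the biased (population) variance and then applies the elementwise affine weight and bias, which is what torch's layer_norm does.

```cpp
#include <cmath>

// Hypothetical host-side reference (not from the commit):
// y[i][j] = (x[i][j] - mean_i) / sqrt(var_i + eps) * weight[j] + bias[j]
void layernorm_reference(const float* x, const float* weight, const float* bias,
                         float* y, int m, int n, float eps) {
  for (int i = 0; i < m; ++i) {
    const float* row = x + i * n;
    // Per-row mean.
    float mean = 0.f;
    for (int j = 0; j < n; ++j) mean += row[j];
    mean /= n;
    // Per-row biased variance, matching torch's layer_norm.
    float var = 0.f;
    for (int j = 0; j < n; ++j) {
      float d = row[j] - mean;
      var += d * d;
    }
    var /= n;
    float inv_std = 1.f / std::sqrt(var + eps);
    // Normalize and apply the elementwise affine transform.
    for (int j = 0; j < n; ++j) {
      y[i * n + j] = (row[j] - mean) * inv_std * weight[j] + bias[j];
    }
  }
}
```

With float32 inputs and n = 512, the rtol=1e-3 / atol=1e-5 tolerances in the test leave room for reduction-order differences between the CUDA kernel and the reference.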
