From 83fd099d67970d720bcb31d5438d398162ae690d Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 11 Mar 2024 14:32:31 -0400 Subject: [PATCH 01/25] refactor matmul files --- crates/burn-jit/src/kernel/matmul/base.rs | 105 ++++++++++++++++-- crates/burn-jit/src/kernel/matmul/mod.rs | 18 ++- .../kernel/matmul/{tiling2d => }/padding.rs | 0 .../matmul/{mem_coalescing.rs => simple.rs} | 4 +- .../{tiling2d/unpadded.rs => tiling2d.rs} | 4 +- .../src/kernel/matmul/tiling2d/base.rs | 91 --------------- .../src/kernel/matmul/tiling2d/mod.rs | 14 --- .../{tiling2d/vec4.rs => tiling2d_padded.rs} | 6 +- .../burn-jit/src/kernel/matmul/tune/base.rs | 47 +++----- 9 files changed, 133 insertions(+), 156 deletions(-) rename crates/burn-jit/src/kernel/matmul/{tiling2d => }/padding.rs (100%) rename crates/burn-jit/src/kernel/matmul/{mem_coalescing.rs => simple.rs} (98%) rename crates/burn-jit/src/kernel/matmul/{tiling2d/unpadded.rs => tiling2d.rs} (93%) delete mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d/base.rs delete mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d/mod.rs rename crates/burn-jit/src/kernel/matmul/{tiling2d/vec4.rs => tiling2d_padded.rs} (86%) diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 3eb9425878..56e2a0d57b 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,8 +1,20 @@ -use crate::{tensor::JitTensor, JitElement, Runtime}; +use burn_compute::server::Handle; +use burn_tensor::Shape; + +use crate::{ + compute::{DynamicKernel, WorkGroup}, + kernel::{build_info, into_contiguous, DynamicKernelSource}, + ops::numeric::empty_device, + tensor::JitTensor, + JitElement, Runtime, +}; use super::{ - init_matmul_output, matmul_autotune, matmul_mem_coalescing, - unpadded::matmul_tiling_2d_unpadded, vec4::matmul_tiling_2d_vec4, + init_matmul_output, matmul_autotune, matmul_simple, + padding::{crop, pad_round, PaddingOutput}, + shape_out, + tiling2d::matmul_tiling_2d, + tiling2d_padded::matmul_tiling_2d_padded, }; /// The strategy to be used when launching a matmul kernel. 
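
For context, this enum is what callers hand to the public matmul entry point. A minimal usage sketch (the burn_jit::kernel::matmul import path matches the test module later in this series; the JitTensor path and the generic bounds are assumptions read off the crate-internal signatures in this patch):

    use burn_jit::kernel::matmul::{matmul, MatmulStrategy};
    use burn_jit::{tensor::JitTensor, JitElement, Runtime};

    // Hypothetical wrapper: force the simple memory-coalescing kernel with a
    // 16x16 launch grid instead of the default strategy (Autotune, when that
    // feature is enabled).
    fn matmul_simple_16x16<R: Runtime, E: JitElement, const D: usize>(
        lhs: JitTensor<R, E, D>,
        rhs: JitTensor<R, E, D>,
    ) -> JitTensor<R, E, D> {
        matmul(lhs, rhs, MatmulStrategy::Simple { grid_x: 16, grid_y: 16 })
    }
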
@@ -25,7 +37,6 @@ pub enum MatmulStrategy { Autotune, } -#[cfg(feature = "autotune")] #[cfg(not(feature = "autotune"))] impl Default for MatmulStrategy { fn default() -> Self { @@ -42,17 +53,97 @@ pub fn matmul( match strategy { MatmulStrategy::Simple { grid_x, grid_y } => { let out = init_matmul_output(&lhs, &rhs); - matmul_mem_coalescing(lhs, rhs, out, grid_x, grid_y) + matmul_simple(lhs, rhs, out, grid_x, grid_y) } MatmulStrategy::Tiling2d => { let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d_unpadded(lhs, rhs, out) + matmul_tiling_2d(lhs, rhs, out) } MatmulStrategy::Tiling2dPadded => { let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d_vec4(lhs, rhs, out) + matmul_tiling_2d_padded(lhs, rhs, out) } #[cfg(feature = "autotune")] MatmulStrategy::Autotune => matmul_autotune(lhs, rhs), } } + +pub(crate) const B_M: usize = 64; +pub(crate) const B_N: usize = 64; +pub(crate) const B_K: usize = 32; +pub(crate) const WORKGROUP_SIZE: usize = 16; + +pub(super) fn make_workgroup(output_shape: &Shape) -> WorkGroup { + let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; + let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; + let mut num_blocks_z = 1; + for i in 0..D - 2 { + num_blocks_z *= output_shape.dims[i]; + } + + WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) +} + +pub(super) fn make_info_handle( + lhs: &JitTensor, + rhs: &JitTensor, + output: &JitTensor, +) -> Handle { + let info = build_info(&[lhs, rhs, output]); + rhs.client.create(bytemuck::cast_slice(&info)) +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn matmul_tiling_2d_launch< + R: Runtime, + E: JitElement, + const D: usize, + K: DynamicKernelSource + 'static, +>( + lhs: JitTensor, + rhs: JitTensor, + output: JitTensor, + kernel: K, +) -> JitTensor { + // A tensor may need to be padded, in which case it will implicitly become contiguous + // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. + // If batches were swapped among themselves, or if the last two dims are transposed, the underlying + // kernel handles it without needing to turn it into contiguous. 
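
To make the rounding behaviour referenced in this comment concrete, here is a self-contained sketch of the round-up arithmetic that pad_round applies to the last two dimensions (plain Rust; the helper name is illustrative, not the crate's API):

    /// Round `x` up to the next multiple of `multiple`.
    fn pad_round_dim(x: usize, multiple: usize) -> usize {
        (x + multiple - 1) / multiple * multiple
    }

    fn main() {
        // With B_M = 64 and B_K = 32, a 100x50 lhs is padded to 128x64.
        assert_eq!(pad_round_dim(100, 64), 128);
        assert_eq!(pad_round_dim(50, 32), 64);
        // Dims already on the block boundary stay put; this is the
        // PaddingOutput::Unchanged case handled just below.
        assert_eq!(pad_round_dim(128, 64), 128);
    }
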
+    let round_lhs = pad_round::<R, E, D>(lhs, B_M, B_K);
+    let lhs = match round_lhs {
+        PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
+            into_contiguous(tensor)
+        }
+        _ => round_lhs.into_tensor(),
+    };
+    let round_rhs = pad_round::<R, E, D>(rhs, B_K, B_N);
+    let rhs = match round_rhs {
+        PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
+            into_contiguous(tensor)
+        }
+        _ => round_rhs.into_tensor(),
+    };
+
+    let rounded_output_shape = shape_out(&lhs, &rhs);
+
+    let rounded_output = empty_device(
+        rhs.client.clone(),
+        rhs.device.clone(),
+        rounded_output_shape.clone(),
+    );
+
+    let workgroup = make_workgroup(&rounded_output_shape);
+    let info_handle = make_info_handle(&lhs, &rhs, &rounded_output);
+
+    lhs.client.execute(
+        Box::new(DynamicKernel::new(kernel, workgroup)),
+        &[
+            &lhs.handle,
+            &rhs.handle,
+            &rounded_output.handle,
+            &info_handle,
+        ],
+    );
+
+    crop(rounded_output, output)
+}
diff --git a/crates/burn-jit/src/kernel/matmul/mod.rs b/crates/burn-jit/src/kernel/matmul/mod.rs
index 920db368b0..7c187f71f4 100644
--- a/crates/burn-jit/src/kernel/matmul/mod.rs
+++ b/crates/burn-jit/src/kernel/matmul/mod.rs
@@ -1,13 +1,23 @@
 mod base;
-mod mem_coalescing;
-mod tiling2d;
+mod simple;
 mod tune;
 
 /// Contains utility functions for the matmul operation
 pub mod utils;
 
 pub use base::*;
-pub use mem_coalescing::*;
-pub use tiling2d::*;
+pub use simple::*;
 pub use tune::*;
 pub use utils::*;
+
+#[cfg(feature = "export_tests")]
+#[allow(missing_docs)]
+pub mod padding;
+
+#[cfg(not(feature = "export_tests"))]
+mod padding;
+
+pub mod tiling2d;
+pub mod tiling2d_padded;
+pub use tiling2d::*;
+pub use tiling2d_padded::*;
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/padding.rs b/crates/burn-jit/src/kernel/matmul/padding.rs
similarity index 100%
rename from crates/burn-jit/src/kernel/matmul/tiling2d/padding.rs
rename to crates/burn-jit/src/kernel/matmul/padding.rs
diff --git a/crates/burn-jit/src/kernel/matmul/mem_coalescing.rs b/crates/burn-jit/src/kernel/matmul/simple.rs
similarity index 98%
rename from crates/burn-jit/src/kernel/matmul/mem_coalescing.rs
rename to crates/burn-jit/src/kernel/matmul/simple.rs
index 8fb305338d..0408f0414a 100644
--- a/crates/burn-jit/src/kernel/matmul/mem_coalescing.rs
+++ b/crates/burn-jit/src/kernel/matmul/simple.rs
@@ -213,11 +213,11 @@ pub fn matmul_mem_coalescing_default<R: Runtime, E: JitElement, const D: usize>(
     rhs: JitTensor<R, E, D>,
     out: JitTensor<R, E, D>,
 ) -> JitTensor<R, E, D> {
-    matmul_mem_coalescing::<R, E, D>(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
+    matmul_simple::<R, E, D>(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
 }
 
 /// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes
-pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
+pub fn matmul_simple<R: Runtime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: JitTensor<R, E, D>,
     out: JitTensor<R, E, D>,
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/unpadded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
similarity index 93%
rename from crates/burn-jit/src/kernel/matmul/tiling2d/unpadded.rs
rename to crates/burn-jit/src/kernel/matmul/tiling2d.rs
index cf88d70cac..1e4c04d9aa 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d/unpadded.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
@@ -15,7 +15,7 @@ use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZ
 
 kernel_wgsl!(
     MatmulTiling2DUnpaddedRaw,
-    "../../../template/matmul/blocktiling_2d/unpadded.wgsl"
+    "../../template/matmul/blocktiling_2d/unpadded.wgsl"
 );
 
 #[derive(new, Debug)]
@@ -45,7 +45,7 @@ impl<E: JitElement> DynamicKernelSource for MatmulTiling2DUnpadded<E> {
 
 /// Matrix multiplication
using tiling 2d algorithm with /// vec4 primitive on both lhs and rhs, with no padding needed -pub fn matmul_tiling_2d_unpadded( +pub fn matmul_tiling_2d( lhs: JitTensor, rhs: JitTensor, out: JitTensor, diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d/base.rs deleted file mode 100644 index 0167464fb6..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d/base.rs +++ /dev/null @@ -1,91 +0,0 @@ -use super::padding::{crop, pad_round, PaddingOutput}; -use crate::{ - compute::{DynamicKernel, WorkGroup}, - element::JitElement, - kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource}, - ops::numeric::empty_device, - tensor::JitTensor, - Runtime, -}; -use burn_compute::server::Handle; -use burn_tensor::Shape; - -pub(crate) const B_M: usize = 64; -pub(crate) const B_N: usize = 64; -pub(crate) const B_K: usize = 32; -pub(crate) const WORKGROUP_SIZE: usize = 16; - -pub(super) fn make_workgroup(output_shape: &Shape) -> WorkGroup { - let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; - let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; - let mut num_blocks_z = 1; - for i in 0..D - 2 { - num_blocks_z *= output_shape.dims[i]; - } - - WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) -} - -pub(super) fn make_info_handle( - lhs: &JitTensor, - rhs: &JitTensor, - output: &JitTensor, -) -> Handle { - let info = build_info(&[lhs, rhs, output]); - rhs.client.create(bytemuck::cast_slice(&info)) -} - -#[allow(clippy::too_many_arguments)] -pub(super) fn matmul_tiling_2d_launch< - R: Runtime, - E: JitElement, - const D: usize, - K: DynamicKernelSource + 'static, ->( - lhs: JitTensor, - rhs: JitTensor, - output: JitTensor, - kernel: K, -) -> JitTensor { - // A tensor may need to be padded, in which case it will implicitly become contiguous - // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. - // If batches were swapped among themselves, or if the last two dims are transposed, the underlying - // kernel handles it without needing to turn it into contiguous. 
- let round_lhs = pad_round::(lhs, B_M, B_K); - let lhs = match round_lhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_lhs.into_tensor(), - }; - let round_rhs = pad_round::(rhs, B_K, B_N); - let rhs = match round_rhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_rhs.into_tensor(), - }; - - let rounded_output_shape = shape_out(&lhs, &rhs); - - let rounded_output = empty_device( - rhs.client.clone(), - rhs.device.clone(), - rounded_output_shape.clone(), - ); - - let workgroup = make_workgroup(&rounded_output_shape); - let info_handle = make_info_handle(&lhs, &rhs, &rounded_output); - - lhs.client.execute( - Box::new(DynamicKernel::new(kernel, workgroup)), - &[ - &lhs.handle, - &rhs.handle, - &rounded_output.handle, - &info_handle, - ], - ); - - crop(rounded_output, output) -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d/mod.rs deleted file mode 100644 index 02f60c8148..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -mod base; - -#[cfg(feature = "export_tests")] -#[allow(missing_docs)] -pub mod padding; - -#[cfg(not(feature = "export_tests"))] -mod padding; - -/// WGSL vec4 primitives are used on left and right hand tensor, -/// padding is avoided through the use of conditions in the kernel -pub mod unpadded; -/// WGSL vec4 primitives are used on left and right hand tensor -pub mod vec4; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/vec4.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs similarity index 86% rename from crates/burn-jit/src/kernel/matmul/tiling2d/vec4.rs rename to crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs index 24d1b91ed9..b21ac0a437 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d/vec4.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs @@ -9,7 +9,7 @@ use std::marker::PhantomData; kernel_wgsl!( MatmulTiling2Dvec4Raw, - "../../../template/matmul/blocktiling_2d/vec4.wgsl" + "../../template/matmul/blocktiling_2d/vec4.wgsl" ); #[derive(new, Debug)] @@ -37,9 +37,7 @@ impl DynamicKernelSource for MatmulTiling2Dvec4 { } } -/// Matrix multiplication using tiling 2d algorithm with -/// vec4 primitive on both lhs and rhs -pub fn matmul_tiling_2d_vec4( +pub fn matmul_tiling_2d_padded( lhs: JitTensor, rhs: JitTensor, out: JitTensor, diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 7208118644..4bfcb9f3a3 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -50,22 +50,14 @@ impl AutotuneOperationSet AutotuneOperationSet, fastest_index: usize) -> Box { match fastest_index { - 0 => Box::new(MemoryCoalescingMatmulDefault::new( - self.lhs, self.rhs, self.out, - )), - 1 => Box::new(MemoryCoalescingMatmulW16x16::new( - self.lhs, self.rhs, self.out, - )), - 2 => Box::new(Vec4TilingMatmulDefault::new(self.lhs, self.rhs, self.out)), - 3 => Box::new(Vec4TilingMatmulUnpaddedDefault::new( - self.lhs, self.rhs, self.out, - )), + 0 => Box::new(SimpleMatmul::new(self.lhs, self.rhs, self.out)), + 1 => Box::new(SimpleMatmul16x16::new(self.lhs, self.rhs, self.out)), + 2 => Box::new(Tiling2DMatmul::new(self.lhs, self.rhs, self.out)), + 3 => Box::new(Tiling2DMatmulPadded::new(self.lhs, self.rhs, self.out)), _ => panic!("Fastest index is out of bound"), } } @@ -134,23 
+120,20 @@ macro_rules! matmul_tune_ops { // Potentially better for small matrices. matmul_tune_ops!( - MemoryCoalescingMatmulDefault, + SimpleMatmul, crate::kernel::matmul::matmul_mem_coalescing_default ); // Potentially better for small matrices. -matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| { - crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16) +matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { + crate::kernel::matmul::matmul_simple(lhs, rhs, out, 16, 16) }); // Probably the fastest when fixed sizes. matmul_tune_ops!( - Vec4TilingMatmulDefault, - crate::kernel::matmul::vec4::matmul_tiling_2d_vec4 + Tiling2DMatmulPadded, + crate::kernel::matmul::matmul_tiling_2d_padded ); -// Probably the fastest otherwise. -matmul_tune_ops!( - Vec4TilingMatmulUnpaddedDefault, - crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded -); +// Probably the fastest in the general case +matmul_tune_ops!(Tiling2DMatmul, crate::kernel::matmul::matmul_tiling_2d); From d0b2ce47264cc7efcc95b8381bbda3e681844f35 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 11 Mar 2024 15:48:24 -0400 Subject: [PATCH 02/25] wip refactor matmul --- crates/burn-jit/src/kernel/matmul/base.rs | 185 ++++++++---- crates/burn-jit/src/kernel/matmul/tiling2d.rs | 279 ++++++++++++++++-- .../src/kernel/matmul/tiling2d_padded.rs | 5 + .../burn-jit/src/kernel/matmul/tune/base.rs | 11 +- 4 files changed, 387 insertions(+), 93 deletions(-) diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 56e2a0d57b..2dac456454 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -10,10 +10,9 @@ use crate::{ }; use super::{ - init_matmul_output, matmul_autotune, matmul_simple, + init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, padding::{crop, pad_round, PaddingOutput}, shape_out, - tiling2d::matmul_tiling_2d, tiling2d_padded::matmul_tiling_2d_padded, }; @@ -22,15 +21,37 @@ use super::{ pub enum MatmulStrategy { /// A simple kernel will be used with memory coalescing optimization. Simple { - /// Grad size x + /// Number of invocations in x grid_x: usize, - /// Grad size y + /// Number of invocations in y grid_y: usize, }, /// A tiling 2d kernel will be used, with support for any matrix size without padding. - Tiling2d, + Tiling2d { + /// Number of invocations in x + grid_x: usize, + /// Number of invocations in y + grid_y: usize, + /// Block size along dimension of lhs + block_size_m: usize, + /// Block size along common dimension + block_size_k: usize, + /// Block size along dimension of rhs + block_size_n: usize, + }, /// A tiling 2d kernel will be used, with support for any matrix size with padding. - Tiling2dPadded, + Tiling2dPadded { + /// Number of invocations in x + grid_x: usize, + /// Number of invocations in y + grid_y: usize, + /// Block size along dimension of lhs + block_size_m: usize, + /// Block size along common dimension + block_size_k: usize, + /// Block size along dimension of rhs + block_size_n: usize, + }, #[cfg(feature = "autotune")] /// Using autotune to chose the best kernel based on runtime information. 
#[default] @@ -55,13 +76,43 @@ pub fn matmul( let out = init_matmul_output(&lhs, &rhs); matmul_simple(lhs, rhs, out, grid_x, grid_y) } - MatmulStrategy::Tiling2d => { + MatmulStrategy::Tiling2d { + grid_x, + grid_y, + block_size_m, + block_size_n, + block_size_k, + } => { let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d(lhs, rhs, out) + matmul_tiling_2d( + lhs, + rhs, + out, + grid_x, + grid_y, + block_size_m, + block_size_n, + block_size_k, + ) } - MatmulStrategy::Tiling2dPadded => { + MatmulStrategy::Tiling2dPadded { + grid_x, + grid_y, + block_size_m, + block_size_n, + block_size_k, + } => { let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d_padded(lhs, rhs, out) + matmul_tiling_2d_padded( + lhs, + rhs, + out, + grid_x, + grid_y, + block_size_m, + block_size_n, + block_size_k, + ) } #[cfg(feature = "autotune")] MatmulStrategy::Autotune => matmul_autotune(lhs, rhs), @@ -73,16 +124,16 @@ pub(crate) const B_N: usize = 64; pub(crate) const B_K: usize = 32; pub(crate) const WORKGROUP_SIZE: usize = 16; -pub(super) fn make_workgroup(output_shape: &Shape) -> WorkGroup { - let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; - let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; - let mut num_blocks_z = 1; - for i in 0..D - 2 { - num_blocks_z *= output_shape.dims[i]; - } +// pub(super) fn make_workgroup(output_shape: &Shape) -> WorkGroup { +// let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; +// let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; +// let mut num_blocks_z = 1; +// for i in 0..D - 2 { +// num_blocks_z *= output_shape.dims[i]; +// } - WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) -} +// WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) +// } pub(super) fn make_info_handle( lhs: &JitTensor, @@ -105,45 +156,67 @@ pub(super) fn matmul_tiling_2d_launch< output: JitTensor, kernel: K, ) -> JitTensor { + todo!() // A tensor may need to be padded, in which case it will implicitly become contiguous // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. // If batches were swapped among themselves, or if the last two dims are transposed, the underlying // kernel handles it without needing to turn it into contiguous. 
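
As a standalone model of the pad, matmul, crop flow this comment describes (pure Rust on flat buffers; the function names are illustrative stand-ins for pad_round, the kernel launch, and crop):

    fn round_up(x: usize, m: usize) -> usize {
        (x + m - 1) / m * m
    }

    /// Copy a rows x cols matrix into the top-left corner of a zero-filled
    /// new_rows x new_cols buffer, mirroring what padding does.
    fn pad(t: &[f32], rows: usize, cols: usize, new_rows: usize, new_cols: usize) -> Vec<f32> {
        let mut out = vec![0.0; new_rows * new_cols];
        for i in 0..rows {
            out[i * new_cols..i * new_cols + cols].copy_from_slice(&t[i * cols..(i + 1) * cols]);
        }
        out
    }

    /// Keep only the top-left rows x cols corner, mirroring crop.
    fn crop(t: &[f32], old_cols: usize, rows: usize, cols: usize) -> Vec<f32> {
        (0..rows)
            .flat_map(|i| t[i * old_cols..i * old_cols + cols].to_vec())
            .collect()
    }

    fn main() {
        // A 3x5 lhs rounded up to multiples of (B_M, B_K) = (64, 32):
        let (m, k) = (3, 5);
        let lhs = vec![1.0_f32; m * k];
        let padded = pad(&lhs, m, k, round_up(m, 64), round_up(k, 32));
        assert_eq!(padded.len(), 64 * 32);
        // Zero padding leaves the valid region intact, so cropping after the
        // kernel recovers exactly the unpadded result.
        assert_eq!(crop(&padded, 32, m, k), lhs);
    }

Because the padded buffers are fresh allocations, padding makes a tensor contiguous as a side effect, which is why the real code only calls into_contiguous in the Unchanged branches.
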
- let round_lhs = pad_round::(lhs, B_M, B_K); - let lhs = match round_lhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_lhs.into_tensor(), - }; - let round_rhs = pad_round::(rhs, B_K, B_N); - let rhs = match round_rhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_rhs.into_tensor(), - }; - - let rounded_output_shape = shape_out(&lhs, &rhs); - - let rounded_output = empty_device( - rhs.client.clone(), - rhs.device.clone(), - rounded_output_shape.clone(), - ); - - let workgroup = make_workgroup(&rounded_output_shape); - let info_handle = make_info_handle(&lhs, &rhs, &rounded_output); - - lhs.client.execute( - Box::new(DynamicKernel::new(kernel, workgroup)), - &[ - &lhs.handle, - &rhs.handle, - &rounded_output.handle, - &info_handle, - ], - ); - - crop(rounded_output, output) + // let round_lhs = pad_round::(lhs, B_M, B_K); + // let lhs = match round_lhs { + // PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + // into_contiguous(tensor) + // } + // _ => round_lhs.into_tensor(), + // }; + // let round_rhs = pad_round::(rhs, B_K, B_N); + // let rhs = match round_rhs { + // PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + // into_contiguous(tensor) + // } + // _ => round_rhs.into_tensor(), + // }; + + // let rounded_output_shape = shape_out(&lhs, &rhs); + + // let rounded_output = empty_device( + // rhs.client.clone(), + // rhs.device.clone(), + // rounded_output_shape.clone(), + // ); + + // let workgroup = make_workgroup(&rounded_output_shape); + // let info_handle = make_info_handle(&lhs, &rhs, &rounded_output); + + // lhs.client.execute( + // Box::new(DynamicKernel::new(kernel, workgroup)), + // &[ + // &lhs.handle, + // &rhs.handle, + // &rounded_output.handle, + // &info_handle, + // ], + // ); + + // crop(rounded_output, output) +} + +pub(crate) fn launch_options( + lhs_shape: &Shape, + rhs_shape: &Shape, + output_shape: &Shape, + workgroup_size_x: usize, + workgroup_size_y: usize, +) -> WorkGroup { + let num_rows = lhs_shape.dims[D - 2]; + let num_cols = rhs_shape.dims[D - 1]; + + // set number of workgroups + let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; + let mut num_iter = 1; + for i in 0..D - 2 { + num_iter *= output_shape.dims[i]; + } + + WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32) } diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index 1e4c04d9aa..75f5af2559 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -1,45 +1,243 @@ use burn_tensor::Element; use crate::{ - compute::DynamicKernel, + codegen::{ + dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, + EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, + }, + compute::{DynamicKernel, WorkGroup}, element::JitElement, - kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource}, + gpu::{gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable}, + kernel::{ + into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, + }, tensor::JitTensor, Runtime, }; use std::marker::PhantomData; -use crate::kernel_wgsl; - -use 
super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZE}; +use super::{ + base::{make_info_handle, B_K, B_M, B_N, WORKGROUP_SIZE}, + launch_options, +}; -kernel_wgsl!( - MatmulTiling2DUnpaddedRaw, - "../../template/matmul/blocktiling_2d/unpadded.wgsl" -); +use burn_tensor::Shape; #[derive(new, Debug)] struct MatmulTiling2DUnpadded { _elem: PhantomData, } -impl DynamicKernelSource for MatmulTiling2DUnpadded { +// impl DynamicKernelSource for MatmulTiling2DUnpadded { +// fn source(&self) -> SourceTemplate { +// MatmulTiling2DUnpaddedRaw::source() +// .register("b_m", B_M.to_string()) +// .register("b_n", B_N.to_string()) +// .register("b_k", B_K.to_string()) +// .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) +// .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) +// .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) +// .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) +// .register("workgroup_size_z", "1".to_string()) +// .register("elem", E::type_name()) +// .register("int", "i32") +// } + +// fn id(&self) -> String { +// std::format!("{:?}", self) +// } +// } + +#[derive(new, Debug)] +struct MatmulTiling2DEagerKernel { + workgroup_size_x: usize, + workgroup_size_y: usize, + block_size_m: usize, + block_size_k: usize, + block_size_n: usize, + _runtime: PhantomData, +} + +struct MatmulTiling2DShader { + variables: BinaryOperator, + block_size: usize, +} + +impl MatmulTiling2DShader { + fn expand(self, scope: &mut Scope) { + // Define out global variables. + let local_idx = Variable::LocalInvocationIndex; + let batch = Variable::GlobalInvocationIdZ; + let rank = Variable::Rank; + let block_size: Variable = self.block_size.into(); + + // Extract tensor variables. + let lhs = self.variables.lhs; + let rhs = self.variables.rhs; + let out = self.variables.out; + + // Define where we have to work on the current matrix. + let tmp_index = scope.create_local(Elem::UInt); + let batch_dims = scope.create_local(Elem::UInt); + let row = scope.create_local(Elem::UInt); + let col = scope.create_local(Elem::UInt); + + // Row position. + gpu!(scope, tmp_index = local_idx / block_size); + gpu!(scope, row = block_size * Variable::WorkgroupIdX); + gpu!(scope, row = row + tmp_index); + + // Col position. + gpu!(scope, tmp_index = local_idx % block_size); + gpu!(scope, col = block_size * Variable::WorkgroupIdY); + gpu!(scope, col = col + tmp_index); + + // Batch position. + gpu!(scope, batch_dims = rank - 2u32); + + // Define the matrix size. + let n_rows = scope.create_local(Elem::UInt); + let n_cols = scope.create_local(Elem::UInt); + let k = scope.create_local(Elem::UInt); + + // Number of rows. + gpu!(scope, n_rows = shape(out, batch_dims)); + + // Number of cols. + gpu!(scope, tmp_index = batch_dims + 1u32); + gpu!(scope, n_cols = shape(out, tmp_index)); + + // The dimension that is going to be squashed. + gpu!(scope, k = shape(lhs, tmp_index)); + + // Check if there is some work to be done. + let should_stop = scope.create_local(Elem::Bool); + gpu!(scope, should_stop = row >= n_rows); + gpu!(scope, if (should_stop).then(|scope| { + scope.register(Branch::Return); + })); + + gpu!(scope, should_stop = col >= n_cols); + gpu!(scope, if (should_stop).then(|scope| { + scope.register(Branch::Return); + })); + + // Calculate the batch offset. + let offset_lhs = scope.zero(Elem::UInt); + let offset_rhs = scope.zero(Elem::UInt); + let offset_output = scope.create_local(Elem::UInt); + + // Batch offset for the output. 
+ gpu!(scope, offset_output = n_rows * n_cols); + gpu!(scope, offset_output = offset_output * batch); + + // Batch offset for the lhs & rhs matrices. + IndexOffsetGlobalWithLayout { + tensors: vec![lhs, rhs], + indexes: vec![offset_lhs, offset_rhs], + layout: out, + index_ref: offset_output, + dim_start: 0u32.into(), + dim_end: batch_dims, + } + .expand(scope); + + // Calculate the dot product (row X col). + let sum = scope.create_local(out.item()); + + // Initialize the sum to zero. + let zero: Variable = 0f32.into(); + gpu!(scope, sum = zero); + + // Loop over the k dimension. + gpu!( + scope, + range(0u32, k).for_each(|i, scope| { + let lhs_index = scope.create_local(Elem::UInt); + let rhs_index = scope.create_local(Elem::UInt); + + let lhs_value = scope.create_local(lhs.item()); + let rhs_value = scope.create_local(rhs.item()); + let out_value = scope.create_local(out.item()); + + gpu!(scope, lhs_index = row * k); + gpu!(scope, lhs_index = lhs_index + i); + gpu!(scope, lhs_index = lhs_index + offset_lhs); + + gpu!(scope, rhs_index = i * n_cols); + gpu!(scope, rhs_index = rhs_index + col); + gpu!(scope, rhs_index = rhs_index + offset_rhs); + + gpu!(scope, lhs_value = lhs[lhs_index]); + gpu!(scope, rhs_value = rhs[rhs_index]); + + gpu!(scope, out_value = lhs_value * rhs_value); + gpu!(scope, sum += out_value); + }) + ); + + let out_index = scope.create_local(Elem::UInt); + + gpu!(scope, out_index = row * n_cols); + gpu!(scope, out_index += col); + gpu!(scope, out_index += offset_output); + gpu!(scope, out[out_index] = sum); + } +} + +impl DynamicKernelSource for MatmulTiling2DEagerKernel { fn source(&self) -> SourceTemplate { - MatmulTiling2DUnpaddedRaw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") + let mut scope = gpu::Scope::root(); + let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into()); + let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into()); + let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into()); + + scope.write_global_custom(out); + + MatmulTiling2DShader { + variables: gpu::BinaryOperator { lhs, rhs, out }, + block_size: self.workgroup_size_x, + } + .expand(&mut scope); + + let lhs = InputInfo::Array { + item: gpu::Elem::Float.into(), + visibility: gpu::Visibility::Read, + }; + let rhs = InputInfo::Array { + item: gpu::Elem::Float.into(), + visibility: gpu::Visibility::Read, + }; + let out = OutputInfo::Array { + item: gpu::Elem::Float.into(), + }; + + let info = CompilationInfo { + inputs: vec![lhs, rhs], + outputs: vec![out], + scope, + }; + + let settings = CompilationSettings::default().workgroup_size(gpu::WorkgroupSize::new( + self.workgroup_size_x as u32, + self.workgroup_size_y as u32, + 1, + )); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) } fn id(&self) -> String { - std::format!("{:?}", self) + format!( + "{:?}x={}y={}b_m={}b_k={}b_n={}", + core::any::TypeId::of::(), + self.workgroup_size_x, + self.workgroup_size_y, + self.block_size_m, + self.block_size_k, + self.block_size_n, + ) } } @@ -49,7 +247,21 @@ pub fn matmul_tiling_2d( 
lhs: JitTensor, rhs: JitTensor, out: JitTensor, + workgroup_size_x: usize, + workgroup_size_y: usize, + block_size_m: usize, + block_size_k: usize, + block_size_n: usize, ) -> JitTensor { + let kernel = MatmulTiling2DEagerKernel::::new( + workgroup_size_x, + workgroup_size_y, + block_size_m, + block_size_k, + block_size_n, + ); + let client = lhs.client.clone(); + let lhs = match lhs.batch_swapped_with_row_col() { true => into_contiguous(lhs), false => lhs, @@ -59,16 +271,19 @@ pub fn matmul_tiling_2d( false => rhs, }; - let workgroup = make_workgroup(&out.shape); - let info_handle = make_info_handle(&lhs, &rhs, &out); - - lhs.client.execute( - Box::new(DynamicKernel::new( - MatmulTiling2DUnpadded::::new(), - workgroup, - )), - &[&lhs.handle, &rhs.handle, &out.handle, &info_handle], - ); + Execution::start(kernel, client) + .inputs(&[ + EagerHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + ]) + .outputs(&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)]) + .execute(WorkgroupLaunch::Custom(launch_options( + &lhs.shape, + &rhs.shape, + &out.shape, + workgroup_size_x, + workgroup_size_y, + ))); out } diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs index b21ac0a437..720847ffa9 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs @@ -41,6 +41,11 @@ pub fn matmul_tiling_2d_padded( lhs: JitTensor, rhs: JitTensor, out: JitTensor, + workgroup_size_x: usize, + workgroup_size_y: usize, + block_size_m: usize, + block_size_k: usize, + block_size_n: usize, ) -> JitTensor { let kernel = MatmulTiling2Dvec4::::new(); matmul_tiling_2d_launch::(lhs, rhs, out, kernel) diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 4bfcb9f3a3..71e77e32be 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -130,10 +130,11 @@ matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { }); // Probably the fastest when fixed sizes. 
-matmul_tune_ops!( - Tiling2DMatmulPadded, - crate::kernel::matmul::matmul_tiling_2d_padded -); +matmul_tune_ops!(Tiling2DMatmulPadded, |lhs, rhs, out| { + crate::kernel::matmul::matmul_tiling_2d_padded(lhs, rhs, out, 16, 16, 64, 32, 64) +}); // Probably the fastest in the general case -matmul_tune_ops!(Tiling2DMatmul, crate::kernel::matmul::matmul_tiling_2d); +matmul_tune_ops!(Tiling2DMatmul, |lhs, rhs, out| { + crate::kernel::matmul::matmul_tiling_2d(lhs, rhs, out, 16, 16, 64, 32, 64) +}); From 603df3f18187111c1a54aa35210ee5a0309fdbb9 Mon Sep 17 00:00:00 2001 From: louisfd Date: Tue, 12 Mar 2024 11:52:32 -0400 Subject: [PATCH 03/25] everything is memco --- crates/burn-jit/src/codegen/compiler.rs | 1 + crates/burn-jit/src/kernel/matmul/base.rs | 241 ++++++--------- crates/burn-jit/src/kernel/matmul/mod.rs | 4 +- crates/burn-jit/src/kernel/matmul/simple.rs | 28 +- crates/burn-jit/src/kernel/matmul/tiling2d.rs | 86 ++---- .../src/kernel/matmul/tiling2d_padded.rs | 278 ++++++++++++++++-- .../burn-jit/src/kernel/matmul/tune/base.rs | 21 +- crates/burn-jit/src/tensor/base.rs | 1 + crates/burn-jit/src/tests/matmul.rs | 26 +- .../burn-wgpu/src/compiler/wgsl/compiler.rs | 4 + 10 files changed, 399 insertions(+), 291 deletions(-) diff --git a/crates/burn-jit/src/codegen/compiler.rs b/crates/burn-jit/src/codegen/compiler.rs index c06cd62be1..d5614607e4 100644 --- a/crates/burn-jit/src/codegen/compiler.rs +++ b/crates/burn-jit/src/codegen/compiler.rs @@ -22,4 +22,5 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { fn compile(shader: gpu::ComputeShader) -> Self::Representation; /// The size of the given element in bytes. fn elem_size(elem: gpu::Elem) -> usize; + fn max_shared_memory_size() -> usize; } diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 2dac456454..2df48eceae 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,23 +1,86 @@ -use burn_compute::server::Handle; +use std::cmp::{max, min}; + use burn_tensor::Shape; -use crate::{ - compute::{DynamicKernel, WorkGroup}, - kernel::{build_info, into_contiguous, DynamicKernelSource}, - ops::numeric::empty_device, - tensor::JitTensor, - JitElement, Runtime, -}; +use crate::{compute::WorkGroup, tensor::JitTensor, Compiler, JitElement, Runtime}; use super::{ init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, - padding::{crop, pad_round, PaddingOutput}, - shape_out, tiling2d_padded::matmul_tiling_2d_padded, }; +#[derive(Debug, Clone)] +/// Tiling 2D parameters +pub struct Tiling2dConfig { + /// Number of invocations in x + pub grid_x: usize, + /// Number of invocations in y + pub grid_y: usize, + /// Block size along dimension of lhs + pub block_size_m: usize, + /// Block size along common dimension + pub block_size_k: usize, + /// Block size along dimension of rhs + pub block_size_n: usize, + /// Tile size along dimension of lhs + pub tile_size_m: usize, + /// Tile size along dimension of rhs + pub tile_size_n: usize, +} + +impl Tiling2dConfig { + #[allow(unused)] + fn new( + grid_x: usize, + grid_y: usize, + block_size_m: usize, + block_size_k: usize, + block_size_n: usize, + tile_size_m: usize, + tile_size_n: usize, + ) -> Self { + assert!(grid_x == f32::ceil(block_size_m as f32 / tile_size_m as f32) as usize); + assert!(grid_y == f32::ceil(block_size_n as f32 / tile_size_n as f32) as usize); + assert!( + block_size_k <= min(block_size_m, block_size_n), + "Not enough invocations to fill 
shared memory" + ); + assert!( + block_size_k * max(block_size_m, block_size_n) + <= ::max_shared_memory_size(), + "Shared memory limit will be busted. " + ); + assert!( + block_size_m % tile_size_m == 0 && block_size_n % tile_size_n == 0, + "Tile size must divide block size in m and n dimensions" + ); + Self { + grid_x, + grid_y, + block_size_m, + block_size_k, + block_size_n, + tile_size_m, + tile_size_n, + } + } +} + +impl Default for Tiling2dConfig { + fn default() -> Self { + Self { + grid_x: 16, + grid_y: 16, + block_size_m: 64, + block_size_k: 32, + block_size_n: 64, + tile_size_m: 4, + tile_size_n: 4, + } + } +} + /// The strategy to be used when launching a matmul kernel. -#[derive(Default)] pub enum MatmulStrategy { /// A simple kernel will be used with memory coalescing optimization. Simple { @@ -27,37 +90,24 @@ pub enum MatmulStrategy { grid_y: usize, }, /// A tiling 2d kernel will be used, with support for any matrix size without padding. - Tiling2d { - /// Number of invocations in x - grid_x: usize, - /// Number of invocations in y - grid_y: usize, - /// Block size along dimension of lhs - block_size_m: usize, - /// Block size along common dimension - block_size_k: usize, - /// Block size along dimension of rhs - block_size_n: usize, - }, + Tiling2d(Tiling2dConfig), /// A tiling 2d kernel will be used, with support for any matrix size with padding. - Tiling2dPadded { - /// Number of invocations in x - grid_x: usize, - /// Number of invocations in y - grid_y: usize, - /// Block size along dimension of lhs - block_size_m: usize, - /// Block size along common dimension - block_size_k: usize, - /// Block size along dimension of rhs - block_size_n: usize, - }, + Tiling2dPadded(Tiling2dConfig), #[cfg(feature = "autotune")] /// Using autotune to chose the best kernel based on runtime information. 
- #[default] Autotune, } +#[cfg(feature = "autotune")] +impl Default for MatmulStrategy { + fn default() -> Self { + MatmulStrategy::Simple { + grid_x: 32, + grid_y: 32, + } + } +} + #[cfg(not(feature = "autotune"))] impl Default for MatmulStrategy { fn default() -> Self { @@ -76,130 +126,19 @@ pub fn matmul( let out = init_matmul_output(&lhs, &rhs); matmul_simple(lhs, rhs, out, grid_x, grid_y) } - MatmulStrategy::Tiling2d { - grid_x, - grid_y, - block_size_m, - block_size_n, - block_size_k, - } => { + MatmulStrategy::Tiling2d(config) => { let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d( - lhs, - rhs, - out, - grid_x, - grid_y, - block_size_m, - block_size_n, - block_size_k, - ) + matmul_tiling_2d(lhs, rhs, out, config) } - MatmulStrategy::Tiling2dPadded { - grid_x, - grid_y, - block_size_m, - block_size_n, - block_size_k, - } => { + MatmulStrategy::Tiling2dPadded(config) => { let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d_padded( - lhs, - rhs, - out, - grid_x, - grid_y, - block_size_m, - block_size_n, - block_size_k, - ) + matmul_tiling_2d_padded(lhs, rhs, out, config) } #[cfg(feature = "autotune")] MatmulStrategy::Autotune => matmul_autotune(lhs, rhs), } } -pub(crate) const B_M: usize = 64; -pub(crate) const B_N: usize = 64; -pub(crate) const B_K: usize = 32; -pub(crate) const WORKGROUP_SIZE: usize = 16; - -// pub(super) fn make_workgroup(output_shape: &Shape) -> WorkGroup { -// let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; -// let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; -// let mut num_blocks_z = 1; -// for i in 0..D - 2 { -// num_blocks_z *= output_shape.dims[i]; -// } - -// WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) -// } - -pub(super) fn make_info_handle( - lhs: &JitTensor, - rhs: &JitTensor, - output: &JitTensor, -) -> Handle { - let info = build_info(&[lhs, rhs, output]); - rhs.client.create(bytemuck::cast_slice(&info)) -} - -#[allow(clippy::too_many_arguments)] -pub(super) fn matmul_tiling_2d_launch< - R: Runtime, - E: JitElement, - const D: usize, - K: DynamicKernelSource + 'static, ->( - lhs: JitTensor, - rhs: JitTensor, - output: JitTensor, - kernel: K, -) -> JitTensor { - todo!() - // A tensor may need to be padded, in which case it will implicitly become contiguous - // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. - // If batches were swapped among themselves, or if the last two dims are transposed, the underlying - // kernel handles it without needing to turn it into contiguous. 
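
The launch_options helper added just below computes the workgroup grid the same way the deleted make_workgroup did, only parameterized by the workgroup sizes instead of the fixed B_M/B_N constants: ceil-divide the output's last two dims and flatten every leading batch dim into z. A plain-Rust model of that arithmetic (WorkGroup stood in by a tuple):

    fn grid(out_dims: &[usize], wg_x: usize, wg_y: usize) -> (u32, u32, u32) {
        let d = out_dims.len();
        let x = ((out_dims[d - 2] + wg_x - 1) / wg_x) as u32;
        let y = ((out_dims[d - 1] + wg_y - 1) / wg_y) as u32;
        let z: usize = out_dims[..d - 2].iter().product();
        (x, y, z as u32)
    }

    fn main() {
        // A [2, 3, 100, 70] output with 16x16 workgroups: a 7 x 5 x 6 grid.
        assert_eq!(grid(&[2, 3, 100, 70], 16, 16), (7, 5, 6));
    }
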
- // let round_lhs = pad_round::(lhs, B_M, B_K); - // let lhs = match round_lhs { - // PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - // into_contiguous(tensor) - // } - // _ => round_lhs.into_tensor(), - // }; - // let round_rhs = pad_round::(rhs, B_K, B_N); - // let rhs = match round_rhs { - // PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - // into_contiguous(tensor) - // } - // _ => round_rhs.into_tensor(), - // }; - - // let rounded_output_shape = shape_out(&lhs, &rhs); - - // let rounded_output = empty_device( - // rhs.client.clone(), - // rhs.device.clone(), - // rounded_output_shape.clone(), - // ); - - // let workgroup = make_workgroup(&rounded_output_shape); - // let info_handle = make_info_handle(&lhs, &rhs, &rounded_output); - - // lhs.client.execute( - // Box::new(DynamicKernel::new(kernel, workgroup)), - // &[ - // &lhs.handle, - // &rhs.handle, - // &rounded_output.handle, - // &info_handle, - // ], - // ); - - // crop(rounded_output, output) -} - pub(crate) fn launch_options( lhs_shape: &Shape, rhs_shape: &Shape, diff --git a/crates/burn-jit/src/kernel/matmul/mod.rs b/crates/burn-jit/src/kernel/matmul/mod.rs index 7c187f71f4..9c405e060d 100644 --- a/crates/burn-jit/src/kernel/matmul/mod.rs +++ b/crates/burn-jit/src/kernel/matmul/mod.rs @@ -1,5 +1,7 @@ mod base; mod simple; +mod tiling2d; +mod tiling2d_padded; mod tune; /// Contains utilitary for matmul operation @@ -17,7 +19,5 @@ pub mod padding; #[cfg(not(feature = "export_tests"))] mod padding; -pub mod tiling2d; -pub mod tiling2d_padded; pub use tiling2d::*; pub use tiling2d_padded::*; diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs index 0408f0414a..f106e1482d 100644 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ b/crates/burn-jit/src/kernel/matmul/simple.rs @@ -6,15 +6,15 @@ use crate::{ dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, InputInfo, OutputInfo, WorkgroupLaunch, }, - compute::WorkGroup, element::JitElement, kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT}, tensor::JitTensor, Runtime, }; -use burn_tensor::Shape; use std::marker::PhantomData; +use super::launch_options; + #[derive(new, Debug)] struct MatmulEagerKernel { workgroup_size_x: usize, @@ -242,9 +242,8 @@ pub fn matmul_simple( &[ EagerHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), - EagerHandle::new(&out.handle, &out.strides, &out.shape.dims), ], - &[], + &[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)], None, kernel, WorkgroupLaunch::Custom(workgroup), @@ -253,24 +252,3 @@ pub fn matmul_simple( out } - -fn launch_options( - lhs_shape: &Shape, - rhs_shape: &Shape, - output_shape: &Shape, - workgroup_size_x: usize, - workgroup_size_y: usize, -) -> WorkGroup { - let num_rows = lhs_shape.dims[D - 2]; - let num_cols = rhs_shape.dims[D - 1]; - - // set number of workgroups - let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; - let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; - let mut num_iter = 1; - for i in 0..D - 2 { - num_iter *= output_shape.dims[i]; - } - - WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32) -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index 75f5af2559..c6af6d36e0 100644 --- 
a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -2,68 +2,36 @@ use burn_tensor::Element; use crate::{ codegen::{ - dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, - EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, + dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, + Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, - compute::{DynamicKernel, WorkGroup}, element::JitElement, gpu::{gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable}, - kernel::{ - into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, - }, + kernel::{into_contiguous, DynamicKernelSource, SourceTemplate}, tensor::JitTensor, Runtime, }; use std::marker::PhantomData; -use super::{ - base::{make_info_handle, B_K, B_M, B_N, WORKGROUP_SIZE}, - launch_options, -}; - -use burn_tensor::Shape; +use super::{launch_options, Tiling2dConfig}; #[derive(new, Debug)] -struct MatmulTiling2DUnpadded { +struct MatmulTiling2d { _elem: PhantomData, } -// impl DynamicKernelSource for MatmulTiling2DUnpadded { -// fn source(&self) -> SourceTemplate { -// MatmulTiling2DUnpaddedRaw::source() -// .register("b_m", B_M.to_string()) -// .register("b_n", B_N.to_string()) -// .register("b_k", B_K.to_string()) -// .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) -// .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) -// .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) -// .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) -// .register("workgroup_size_z", "1".to_string()) -// .register("elem", E::type_name()) -// .register("int", "i32") -// } - -// fn id(&self) -> String { -// std::format!("{:?}", self) -// } -// } - #[derive(new, Debug)] -struct MatmulTiling2DEagerKernel { - workgroup_size_x: usize, - workgroup_size_y: usize, - block_size_m: usize, - block_size_k: usize, - block_size_n: usize, +struct MatmulTiling2dEagerKernel { + config: Tiling2dConfig, _runtime: PhantomData, } -struct MatmulTiling2DShader { +struct MatmulTiling2dShader { variables: BinaryOperator, block_size: usize, } -impl MatmulTiling2DShader { +impl MatmulTiling2dShader { fn expand(self, scope: &mut Scope) { // Define out global variables. 
let local_idx = Variable::LocalInvocationIndex; @@ -185,7 +153,7 @@ impl MatmulTiling2DShader { } } -impl DynamicKernelSource for MatmulTiling2DEagerKernel { +impl DynamicKernelSource for MatmulTiling2dEagerKernel { fn source(&self) -> SourceTemplate { let mut scope = gpu::Scope::root(); let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into()); @@ -194,9 +162,9 @@ impl DynamicKernelSource for MatmulTiling2DEagerKernel { scope.write_global_custom(out); - MatmulTiling2DShader { + MatmulTiling2dShader { variables: gpu::BinaryOperator { lhs, rhs, out }, - block_size: self.workgroup_size_x, + block_size: self.config.grid_x, // TODO } .expand(&mut scope); @@ -219,8 +187,8 @@ impl DynamicKernelSource for MatmulTiling2DEagerKernel { }; let settings = CompilationSettings::default().workgroup_size(gpu::WorkgroupSize::new( - self.workgroup_size_x as u32, - self.workgroup_size_y as u32, + self.config.grid_x as u32, + self.config.grid_y as u32, 1, )); let shader = Compilation::new(info).compile(settings); @@ -230,13 +198,9 @@ impl DynamicKernelSource for MatmulTiling2DEagerKernel { fn id(&self) -> String { format!( - "{:?}x={}y={}b_m={}b_k={}b_n={}", + "{:?}config={:?}", core::any::TypeId::of::(), - self.workgroup_size_x, - self.workgroup_size_y, - self.block_size_m, - self.block_size_k, - self.block_size_n, + self.config, ) } } @@ -247,19 +211,9 @@ pub fn matmul_tiling_2d( lhs: JitTensor, rhs: JitTensor, out: JitTensor, - workgroup_size_x: usize, - workgroup_size_y: usize, - block_size_m: usize, - block_size_k: usize, - block_size_n: usize, + config: Tiling2dConfig, ) -> JitTensor { - let kernel = MatmulTiling2DEagerKernel::::new( - workgroup_size_x, - workgroup_size_y, - block_size_m, - block_size_k, - block_size_n, - ); + let kernel = MatmulTiling2dEagerKernel::::new(config.clone()); let client = lhs.client.clone(); let lhs = match lhs.batch_swapped_with_row_col() { @@ -281,8 +235,8 @@ pub fn matmul_tiling_2d( &lhs.shape, &rhs.shape, &out.shape, - workgroup_size_x, - workgroup_size_y, + config.grid_x, + config.grid_y, ))); out diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs index 720847ffa9..1f854acc68 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs @@ -1,52 +1,272 @@ -use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE}; +use burn_tensor::Element; + use crate::{ + codegen::{ + dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, + Execution, InputInfo, OutputInfo, WorkgroupLaunch, + }, element::JitElement, - kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, + gpu::{gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable}, + kernel::{into_contiguous, DynamicKernelSource, SourceTemplate}, tensor::JitTensor, + Runtime, }; -use crate::{kernel_wgsl, Runtime}; use std::marker::PhantomData; -kernel_wgsl!( - MatmulTiling2Dvec4Raw, - "../../template/matmul/blocktiling_2d/vec4.wgsl" -); +use super::{ + launch_options, + padding::{crop, pad_round, PaddingOutput}, + shape_out, Tiling2dConfig, +}; #[derive(new, Debug)] -struct MatmulTiling2Dvec4 { +struct MatmulTiling2dPadded { _elem: PhantomData, } -impl DynamicKernelSource for MatmulTiling2Dvec4 { +#[derive(new, Debug)] +struct MatmulTiling2dPaddedEagerKernel { + config: Tiling2dConfig, + _runtime: PhantomData, +} + +struct MatmulTiling2dPaddedShader { + variables: BinaryOperator, + block_size: usize, +} 
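
The expand body that follows mirrors the unpadded shader above: despite the tiling2d name, at this point in the series both still emit the simple one-thread-per-output-element kernel. A plain-Rust model of what one invocation computes (CPU stand-in; the names are illustrative):

    /// One simulated invocation of the generated kernel: derive (row, col)
    /// from the workgroup and local indices, early-out when out of bounds,
    /// then accumulate the dot product over k for one output element.
    fn invocation(
        workgroup_id: (usize, usize, usize), // (x, y, batch)
        local_idx: usize,
        block_size: usize,
        lhs: &[f32],     // row-major [batch, m, k]
        rhs: &[f32],     // row-major [batch, k, n]
        out: &mut [f32], // row-major [batch, m, n]
        (m, k, n): (usize, usize, usize),
    ) {
        let row = workgroup_id.0 * block_size + local_idx / block_size;
        let col = workgroup_id.1 * block_size + local_idx % block_size;
        if row >= m || col >= n {
            return; // the should_stop branches
        }
        let (off_lhs, off_rhs, off_out) = (
            workgroup_id.2 * m * k,
            workgroup_id.2 * k * n,
            workgroup_id.2 * m * n,
        );
        let mut sum = 0.0;
        for i in 0..k {
            sum += lhs[off_lhs + row * k + i] * rhs[off_rhs + i * n + col];
        }
        out[off_out + row * n + col] = sum;
    }

    fn main() {
        // 2x2 times the identity, one 16x16 workgroup:
        let lhs = vec![1.0, 2.0, 3.0, 4.0];
        let rhs = vec![1.0, 0.0, 0.0, 1.0];
        let mut out = vec![0.0; 4];
        for local_idx in 0..16 * 16 {
            invocation((0, 0, 0), local_idx, 16, &lhs, &rhs, &mut out, (2, 2, 2));
        }
        assert_eq!(out, lhs);
    }

One caveat kept out of the model: the real kernel computes the lhs/rhs batch offsets through IndexOffsetGlobalWithLayout so that non-contiguous strides and broadcast batches are honoured, while this sketch assumes packed row-major buffers.
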
+ +impl MatmulTiling2dPaddedShader { + fn expand(self, scope: &mut Scope) { + // Define out global variables. + let local_idx = Variable::LocalInvocationIndex; + let batch = Variable::GlobalInvocationIdZ; + let rank = Variable::Rank; + let block_size: Variable = self.block_size.into(); + + // Extract tensor variables. + let lhs = self.variables.lhs; + let rhs = self.variables.rhs; + let out = self.variables.out; + + // Define where we have to work on the current matrix. + let tmp_index = scope.create_local(Elem::UInt); + let batch_dims = scope.create_local(Elem::UInt); + let row = scope.create_local(Elem::UInt); + let col = scope.create_local(Elem::UInt); + + // Row position. + gpu!(scope, tmp_index = local_idx / block_size); + gpu!(scope, row = block_size * Variable::WorkgroupIdX); + gpu!(scope, row = row + tmp_index); + + // Col position. + gpu!(scope, tmp_index = local_idx % block_size); + gpu!(scope, col = block_size * Variable::WorkgroupIdY); + gpu!(scope, col = col + tmp_index); + + // Batch position. + gpu!(scope, batch_dims = rank - 2u32); + + // Define the matrix size. + let n_rows = scope.create_local(Elem::UInt); + let n_cols = scope.create_local(Elem::UInt); + let k = scope.create_local(Elem::UInt); + + // Number of rows. + gpu!(scope, n_rows = shape(out, batch_dims)); + + // Number of cols. + gpu!(scope, tmp_index = batch_dims + 1u32); + gpu!(scope, n_cols = shape(out, tmp_index)); + + // The dimension that is going to be squashed. + gpu!(scope, k = shape(lhs, tmp_index)); + + // Check if there is some work to be done. + let should_stop = scope.create_local(Elem::Bool); + gpu!(scope, should_stop = row >= n_rows); + gpu!(scope, if (should_stop).then(|scope| { + scope.register(Branch::Return); + })); + + gpu!(scope, should_stop = col >= n_cols); + gpu!(scope, if (should_stop).then(|scope| { + scope.register(Branch::Return); + })); + + // Calculate the batch offset. + let offset_lhs = scope.zero(Elem::UInt); + let offset_rhs = scope.zero(Elem::UInt); + let offset_output = scope.create_local(Elem::UInt); + + // Batch offset for the output. + gpu!(scope, offset_output = n_rows * n_cols); + gpu!(scope, offset_output = offset_output * batch); + + // Batch offset for the lhs & rhs matrices. + IndexOffsetGlobalWithLayout { + tensors: vec![lhs, rhs], + indexes: vec![offset_lhs, offset_rhs], + layout: out, + index_ref: offset_output, + dim_start: 0u32.into(), + dim_end: batch_dims, + } + .expand(scope); + + // Calculate the dot product (row X col). + let sum = scope.create_local(out.item()); + + // Initialize the sum to zero. + let zero: Variable = 0f32.into(); + gpu!(scope, sum = zero); + + // Loop over the k dimension. 
+ gpu!( + scope, + range(0u32, k).for_each(|i, scope| { + let lhs_index = scope.create_local(Elem::UInt); + let rhs_index = scope.create_local(Elem::UInt); + + let lhs_value = scope.create_local(lhs.item()); + let rhs_value = scope.create_local(rhs.item()); + let out_value = scope.create_local(out.item()); + + gpu!(scope, lhs_index = row * k); + gpu!(scope, lhs_index = lhs_index + i); + gpu!(scope, lhs_index = lhs_index + offset_lhs); + + gpu!(scope, rhs_index = i * n_cols); + gpu!(scope, rhs_index = rhs_index + col); + gpu!(scope, rhs_index = rhs_index + offset_rhs); + + gpu!(scope, lhs_value = lhs[lhs_index]); + gpu!(scope, rhs_value = rhs[rhs_index]); + + gpu!(scope, out_value = lhs_value * rhs_value); + gpu!(scope, sum += out_value); + }) + ); + + let out_index = scope.create_local(Elem::UInt); + + gpu!(scope, out_index = row * n_cols); + gpu!(scope, out_index += col); + gpu!(scope, out_index += offset_output); + gpu!(scope, out[out_index] = sum); + } +} + +impl DynamicKernelSource for MatmulTiling2dPaddedEagerKernel { fn source(&self) -> SourceTemplate { - MatmulTiling2Dvec4Raw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") + let mut scope = gpu::Scope::root(); + let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into()); + let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into()); + let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into()); + + scope.write_global_custom(out); + + MatmulTiling2dPaddedShader { + variables: gpu::BinaryOperator { lhs, rhs, out }, + block_size: self.config.grid_x, // TODO + } + .expand(&mut scope); + + let lhs = InputInfo::Array { + item: gpu::Elem::Float.into(), + visibility: gpu::Visibility::Read, + }; + let rhs = InputInfo::Array { + item: gpu::Elem::Float.into(), + visibility: gpu::Visibility::Read, + }; + let out = OutputInfo::Array { + item: gpu::Elem::Float.into(), + }; + + let info = CompilationInfo { + inputs: vec![lhs, rhs], + outputs: vec![out], + scope, + }; + + let settings = CompilationSettings::default().workgroup_size(gpu::WorkgroupSize::new( + self.config.grid_x as u32, + self.config.grid_y as u32, + 1, + )); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) } fn id(&self) -> String { - std::format!("{:?}", self) + format!( + "{:?}config={:?}", + core::any::TypeId::of::(), + self.config, + ) } } -pub fn matmul_tiling_2d_padded( +/// Matrix multiplication using tiling 2d algorithm with +/// vec4 primitive on both lhs and rhs, with no padding needed +pub fn matmul_tiling_2d_padded( lhs: JitTensor, rhs: JitTensor, out: JitTensor, - workgroup_size_x: usize, - workgroup_size_y: usize, - block_size_m: usize, - block_size_k: usize, - block_size_n: usize, + config: Tiling2dConfig, ) -> JitTensor { - let kernel = MatmulTiling2Dvec4::::new(); - matmul_tiling_2d_launch::(lhs, rhs, out, kernel) + let kernel = MatmulTiling2dPaddedEagerKernel::::new(config.clone()); + let client = lhs.client.clone(); + + // A tensor may need to be padded, in which case it will implicitly become contiguous + // If not needed, it is only turned into 
contiguous if some batch dim has been swapped with row or col dim. + // If batches were swapped among themselves, or if the last two dims are transposed, the underlying + // kernel handles it without needing to turn it into contiguous. + let round_lhs = pad_round::(lhs, config.block_size_m, config.block_size_k); + let lhs = match round_lhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_lhs.into_tensor(), + }; + let round_rhs = pad_round::(rhs, config.block_size_k, config.block_size_n); + let rhs = match round_rhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_rhs.into_tensor(), + }; + + let rounded_output_shape = shape_out(&lhs, &rhs); + + let num_elems = rounded_output_shape.num_elements(); + let buffer = client.empty(num_elems * core::mem::size_of::()); + let rounded_output = JitTensor::new( + rhs.client.clone(), + rhs.device.clone(), + rounded_output_shape.clone(), + buffer, + ); + + Execution::start(kernel, client) + .inputs(&[ + EagerHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + ]) + .outputs(&[EagerHandle::new( + &rounded_output.handle, + &rounded_output.strides, + &rounded_output.shape.dims, + )]) + .execute(WorkgroupLaunch::Custom(launch_options( + &lhs.shape, + &rhs.shape, + &out.shape, + config.grid_x, + config.grid_y, + ))); + + crop(rounded_output, out) } diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 71e77e32be..1fa9ab0a9e 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -4,7 +4,10 @@ use burn_tensor::{Element, ElementConversion}; use crate::{ compute::JitAutotuneKey, element::JitElement, - kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, + kernel::{ + matmul::{utils::init_matmul_output, Tiling2dConfig}, + prng::random_like_uniform, + }, ops::numeric::empty_device, tensor::JitTensor, Runtime, @@ -56,8 +59,8 @@ impl AutotuneOperationSet AutotuneOperationSet Box::new(SimpleMatmul::new(self.lhs, self.rhs, self.out)), 1 => Box::new(SimpleMatmul16x16::new(self.lhs, self.rhs, self.out)), - 2 => Box::new(Tiling2DMatmul::new(self.lhs, self.rhs, self.out)), - 3 => Box::new(Tiling2DMatmulPadded::new(self.lhs, self.rhs, self.out)), + 2 => Box::new(Tiling2dMatmul::new(self.lhs, self.rhs, self.out)), + 3 => Box::new(Tiling2dMatmulPadded::new(self.lhs, self.rhs, self.out)), _ => panic!("Fastest index is out of bound"), } } @@ -130,11 +133,11 @@ matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { }); // Probably the fastest when fixed sizes. 
-matmul_tune_ops!(Tiling2DMatmulPadded, |lhs, rhs, out| {
-    crate::kernel::matmul::matmul_tiling_2d_padded(lhs, rhs, out, 16, 16, 64, 32, 64)
+matmul_tune_ops!(Tiling2dMatmulPadded, |lhs, rhs, out| {
+    crate::kernel::matmul::matmul_tiling_2d_padded(lhs, rhs, out, Tiling2dConfig::default())
 });
 
 // Probably the fastest in the general case
-matmul_tune_ops!(Tiling2DMatmul, |lhs, rhs, out| {
-    crate::kernel::matmul::matmul_tiling_2d(lhs, rhs, out, 16, 16, 64, 32, 64)
+matmul_tune_ops!(Tiling2dMatmul, |lhs, rhs, out| {
+    crate::kernel::matmul::matmul_tiling_2d(lhs, rhs, out, Tiling2dConfig::default())
 });
diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs
index d0034e4112..4268a6f573 100644
--- a/crates/burn-jit/src/tensor/base.rs
+++ b/crates/burn-jit/src/tensor/base.rs
@@ -187,6 +187,7 @@ where
     }
 
     pub(crate) fn batch_swapped_with_row_col(&self) -> bool {
+        println!("{:?}", self.strides);
         for d in 0..D - 2 {
             let stride = self.strides[d];
             if stride < self.strides[D - 2] || stride < self.strides[D - 1] {
diff --git a/crates/burn-jit/src/tests/matmul.rs b/crates/burn-jit/src/tests/matmul.rs
index 0ec2a66bf6..e91e5cbcc2 100644
--- a/crates/burn-jit/src/tests/matmul.rs
+++ b/crates/burn-jit/src/tests/matmul.rs
@@ -1,7 +1,7 @@
 #[burn_tensor_testgen::testgen(matmul)]
 mod tests {
     use super::*;
-    use burn_jit::kernel::matmul::{matmul, MatmulStrategy};
+    use burn_jit::kernel::matmul::{matmul, MatmulStrategy, Tiling2dConfig};
     use burn_tensor::{Shape, Tensor};
 
     mod simple {
@@ -174,7 +174,7 @@ mod tests {
         let shape_lhs = [3, 2, 4, 4];
         let shape_rhs = [3, 2, 4, 4];
         same_as_reference_swapped_dims(
-            MatmulStrategy::Tiling2dPadded,
+            MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
             swap,
             swap,
             shape_lhs,
@@ -189,7 +189,7 @@ mod tests {
         let shape_lhs = [3, 2, 4, 4];
         let shape_rhs = [3, 2, 4, 4];
         same_as_reference_swapped_dims(
-            MatmulStrategy::Tiling2dPadded,
+            MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
             swap_lhs,
             swap_rhs,
             shape_lhs,
@@ -204,7 +204,7 @@ mod tests {
         let shape_lhs = [4, 4, 4, 4];
         let shape_rhs = [4, 4, 4, 4];
         same_as_reference_swapped_dims(
-            MatmulStrategy::Tiling2dPadded,
+            MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
             swap_lhs,
             swap_rhs,
             shape_lhs,
@@ -215,7 +215,11 @@ mod tests {
     fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
         let shape_lhs = [batch_1, batch_2, m, k];
         let shape_rhs = [batch_1, batch_2, k, n];
-        same_as_reference(MatmulStrategy::Tiling2dPadded, shape_lhs, shape_rhs);
+        same_as_reference(
+            MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
+            shape_lhs,
+            shape_rhs,
+        );
     }
 }
 
@@ -308,7 +312,7 @@ mod tests {
         let shape_lhs = [3, 2, 4, 4];
         let shape_rhs = [3, 2, 4, 4];
         same_as_reference_swapped_dims(
-            MatmulStrategy::Tiling2d,
+            MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
             swap,
             swap,
             shape_lhs,
@@ -323,7 +327,7 @@ mod tests {
         let shape_lhs = [3, 2, 4, 4];
         let shape_rhs = [3, 2, 4, 4];
         same_as_reference_swapped_dims(
-            MatmulStrategy::Tiling2d,
+            MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
             swap_lhs,
             swap_rhs,
             shape_lhs,
@@ -338,7 +342,7 @@ mod tests {
         let shape_lhs = [4, 4, 4, 4];
         let shape_rhs = [4, 4, 4, 4];
         same_as_reference_swapped_dims(
-            MatmulStrategy::Tiling2d,
+            MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
             swap_lhs,
             swap_rhs,
             shape_lhs,
@@ -349,7 +353,11 @@ mod tests {
     fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
         let shape_lhs = [batch_1, batch_2, m, k];
         let shape_rhs = [batch_1, batch_2, k, n];
-        same_as_reference(MatmulStrategy::Tiling2d, shape_lhs, shape_rhs);
+        same_as_reference(
+            MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
+            shape_lhs,
+            shape_rhs,
+        );
     }
 }
diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs
index 3a2f6a91dd..f84be0c019 100644
--- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs
+++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs
@@ -64,6 +64,10 @@ impl<F: FloatElement, I: IntElement> burn_jit::Compiler for WgslCompiler<F, I> {
     fn elem_size(elem: gpu::Elem) -> usize {
         Self::compile_elem(elem).size()
     }
+
+    fn max_shared_memory_size() -> usize {
+        8192
+    }
 }
 
 impl WgslCompiler {

From 423d23d4e31c4098ee82e9b835069a5771274b10 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Tue, 12 Mar 2024 15:53:16 -0400
Subject: [PATCH 04/25] support local arrays

---
 crates/burn-jit/src/codegen/compiler.rs       |   1 +
 .../burn-jit/src/codegen/dialect/gpu/scope.rs |  26 ++-
 .../src/codegen/dialect/gpu/variable.rs       |   3 +
 .../src/codegen/dialect/gpu/vectorization.rs  |   6 +
 .../src/kernel/matmul/tiling2d_padded.rs      | 169 ++++++++----------
 crates/burn-wgpu/src/compiler/wgsl/base.rs    |   6 +
 .../burn-wgpu/src/compiler/wgsl/compiler.rs   |  12 ++
 crates/burn-wgpu/src/compiler/wgsl/shader.rs  |  49 +++--
 8 files changed, 164 insertions(+), 108 deletions(-)

diff --git a/crates/burn-jit/src/codegen/compiler.rs b/crates/burn-jit/src/codegen/compiler.rs
index d5614607e4..250fecfe92 100644
--- a/crates/burn-jit/src/codegen/compiler.rs
+++ b/crates/burn-jit/src/codegen/compiler.rs
@@ -22,5 +22,6 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
     fn compile(shader: gpu::ComputeShader) -> Self::Representation;
     /// The size of the given element in bytes.
     fn elem_size(elem: gpu::Elem) -> usize;
+    /// The maximal size of a shared memory allocation.
     fn max_shared_memory_size() -> usize;
 }
diff --git a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs
index dc70fafd09..47868acfac 100644
--- a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs
+++ b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs
@@ -20,7 +20,8 @@ pub struct Scope {
     pub depth: u8,
     pub operations: Vec<Operation>,
     locals: Vec<Variable>,
-    shared: Vec<Variable>,
+    shared_memories: Vec<Variable>,
+    local_arrays: Vec<Variable>,
     reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
     index_offset_with_output_layout_position: Vec<usize>,
     writes_global: Vec<(Variable, Variable)>,
@@ -48,7 +49,8 @@ impl Scope {
             depth: 0,
             operations: Vec::new(),
             locals: Vec::new(),
-            shared: Vec::new(),
+            local_arrays: Vec::new(),
+            shared_memories: Vec::new(),
             reads_global: Vec::new(),
             index_offset_with_output_layout_position: Vec::new(),
             writes_global: Vec::new(),
@@ -213,7 +215,8 @@ impl Scope {
             depth: self.depth + 1,
             operations: Vec::new(),
             locals: Vec::new(),
-            shared: Vec::new(),
+            shared_memories: Vec::new(),
+            local_arrays: Vec::new(),
             reads_global: Vec::new(),
             index_offset_with_output_layout_position: Vec::new(),
             writes_global: Vec::new(),
@@ -308,7 +311,11 @@ impl Scope {
     }
 
     fn new_shared_index(&self) -> u16 {
-        self.shared.len() as u16
+        self.shared_memories.len() as u16
+    }
+
+    fn new_local_array_index(&self) -> u16 {
+        self.local_arrays.len() as u16
     }
 
     fn read_input_strategy(
@@ -339,7 +346,16 @@ impl Scope {
         let item = item.into();
         let index = self.new_shared_index();
         let shared_memory = Variable::SharedMemory(index, item, shared_memory_size);
-        self.shared.push(shared_memory);
+        self.local_arrays.push(shared_memory);
         shared_memory
     }
+
+    /// Create a local array of the given [item type](Item).
+    pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
+        let item = item.into();
+        let index = self.new_local_array_index();
+        let local_array = Variable::LocalArray(index, item, self.depth, array_size);
+        self.local_arrays.push(local_array);
+        local_array
+    }
 }
diff --git a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs
index 7a49d5eea0..cd52ea4134 100644
--- a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs
+++ b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs
@@ -11,6 +11,7 @@ pub enum Variable {
     LocalScalar(u16, Elem, u8),
     ConstantScalar(f64, Elem),
     SharedMemory(u16, Item, u32),
+    LocalArray(u16, Item, u8, u32),
     Id,
     LocalInvocationIndex,
     LocalInvocationIdX,
@@ -41,6 +42,7 @@ impl Variable {
             Variable::GlobalOutputArray(idx, _) => Some(*idx),
             Variable::ConstantScalar(_, _) => None,
             Variable::SharedMemory(idx, _, _) => Some(*idx),
+            Variable::LocalArray(idx, _, _, _) => Some(*idx),
             Variable::Id => None,
             Variable::LocalInvocationIndex => None,
             Variable::LocalInvocationIdX => None,
@@ -70,6 +72,7 @@ impl Variable {
             Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
             Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
             Variable::SharedMemory(_, item, _) => *item,
+            Variable::LocalArray(_, item, _, _) => *item,
             Variable::Id => Item::Scalar(Elem::UInt),
             Variable::Rank => Item::Scalar(Elem::UInt),
             Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
diff --git a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs
index 3d06a7652b..2c59a1d1ea 100644
--- a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs
+++ b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs
@@ -129,6 +129,12 @@ impl Variable {
                 item.vectorize(vectorize),
                 item.vectorized_size(vectorize, *size),
             ),
+            Variable::LocalArray(index, item, name, size) => Variable::LocalArray(
+                *index,
+                item.vectorize(vectorize),
+                *name,
+                item.vectorized_size(vectorize, *size),
+            ),
             Variable::ConstantScalar(_, _) => *self,
             Variable::GlobalScalar(_, _) => *self,
             Variable::Id => *self,
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs
index 1f854acc68..9b9e937301 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs
@@ -32,78 +32,94 @@ struct MatmulTiling2dPaddedEagerKernel<R: Runtime> {
 
 struct MatmulTiling2dPaddedShader {
     variables: BinaryOperator,
-    block_size: usize,
+    config: Tiling2dConfig,
 }
 
 impl MatmulTiling2dPaddedShader {
     fn expand(self, scope: &mut Scope) {
-        // Define out global variables.
-        let local_idx = Variable::LocalInvocationIndex;
-        let batch = Variable::GlobalInvocationIdZ;
-        let rank = Variable::Rank;
-        let block_size: Variable = self.block_size.into();
+        // Phase 1: Gather information: input, shader and offsets
 
-        // Extract tensor variables.
+        // Inputs
         let lhs = self.variables.lhs;
         let rhs = self.variables.rhs;
         let out = self.variables.out;
 
-        // Define where we have to work on the current matrix.
-        let tmp_index = scope.create_local(Elem::UInt);
-        let batch_dims = scope.create_local(Elem::UInt);
-        let row = scope.create_local(Elem::UInt);
-        let col = scope.create_local(Elem::UInt);
-
-        // Row position.
- gpu!(scope, tmp_index = local_idx / block_size); - gpu!(scope, row = block_size * Variable::WorkgroupIdX); - gpu!(scope, row = row + tmp_index); + // Config variables + let block_size_m: Variable = self.config.block_size_m.into(); + let block_size_n: Variable = self.config.block_size_n.into(); + let tile_size_m: Variable = self.config.tile_size_m.into(); + let tile_size_n: Variable = self.config.tile_size_n.into(); + let n_threads_per_row: Variable = + (((self.config.block_size_n - 1) / self.config.tile_size_n) + 1).into(); + let results_size = (self.config.tile_size_m * self.config.tile_size_n) as u32; - // Col position. - gpu!(scope, tmp_index = local_idx % block_size); - gpu!(scope, col = block_size * Variable::WorkgroupIdY); - gpu!(scope, col = col + tmp_index); - - // Batch position. - gpu!(scope, batch_dims = rank - 2u32); + // Shader info + let local_idx = Variable::LocalInvocationIndex; + let batch = Variable::GlobalInvocationIdZ; - // Define the matrix size. - let n_rows = scope.create_local(Elem::UInt); - let n_cols = scope.create_local(Elem::UInt); + // Shapes + let rank = Variable::Rank; + let penultimate_dim = scope.create_local(Elem::UInt); + gpu!(scope, penultimate_dim = rank - 1u32); + let m = scope.create_local(Elem::UInt); let k = scope.create_local(Elem::UInt); + let n = scope.create_local(Elem::UInt); + gpu!(scope, m = shape(lhs, penultimate_dim)); + gpu!(scope, k = shape(rhs, penultimate_dim)); + gpu!(scope, n = shape(rhs, rank)); + + // Strides + let lhs_stride_row = scope.create_local(Elem::UInt); + let lhs_stride_col = scope.create_local(Elem::UInt); + let rhs_stride_row = scope.create_local(Elem::UInt); + let rhs_stride_col = scope.create_local(Elem::UInt); + let out_stride_row = scope.create_local(Elem::UInt); + let out_stride_col = scope.create_local(Elem::UInt); + gpu!(scope, lhs_stride_row = stride(lhs, penultimate_dim)); + gpu!(scope, lhs_stride_col = stride(lhs, rank)); + gpu!(scope, rhs_stride_row = stride(rhs, penultimate_dim)); + gpu!(scope, rhs_stride_col = stride(rhs, rank)); + gpu!(scope, out_stride_row = stride(out, penultimate_dim)); + gpu!(scope, out_stride_col = stride(out, rank)); + + // Workgroup offset + let skip_row = scope.create_local(Elem::UInt); + let workgroup_id_x = Variable::WorkgroupIdX; + gpu!(scope, skip_row = workgroup_id_x); + gpu!(scope, skip_row *= block_size_m); + let skip_col = scope.create_local(Elem::UInt); + let workgroup_id_y = Variable::WorkgroupIdY; + gpu!(scope, skip_col = workgroup_id_y); + gpu!(scope, skip_col *= block_size_n); + + // Invocation offset + let thread_row = scope.create_local(Elem::UInt); + gpu!(scope, thread_row = local_idx / n_threads_per_row); + gpu!(scope, thread_row *= tile_size_m); + let thread_col = scope.create_local(Elem::UInt); + gpu!(scope, thread_col = local_idx % n_threads_per_row); + gpu!(scope, thread_col *= tile_size_n); + + // Row and col + let row = scope.create_local(Elem::UInt); + let col = scope.create_local(Elem::UInt); + gpu!(scope, row = skip_row + thread_row); + gpu!(scope, col = skip_col + thread_col); - // Number of rows. - gpu!(scope, n_rows = shape(out, batch_dims)); - - // Number of cols. - gpu!(scope, tmp_index = batch_dims + 1u32); - gpu!(scope, n_cols = shape(out, tmp_index)); - - // The dimension that is going to be squashed. - gpu!(scope, k = shape(lhs, tmp_index)); - - // Check if there is some work to be done. 
- let should_stop = scope.create_local(Elem::Bool); - gpu!(scope, should_stop = row >= n_rows); - gpu!(scope, if (should_stop).then(|scope| { - scope.register(Branch::Return); - })); - - gpu!(scope, should_stop = col >= n_cols); - gpu!(scope, if (should_stop).then(|scope| { - scope.register(Branch::Return); - })); - - // Calculate the batch offset. - let offset_lhs = scope.zero(Elem::UInt); - let offset_rhs = scope.zero(Elem::UInt); - let offset_output = scope.create_local(Elem::UInt); + // Calculate offset. + let offset_lhs = scope.create_local(Elem::UInt); + let offset_rhs = scope.create_local(Elem::UInt); + gpu!(scope, offset_lhs = skip_row * lhs_stride_row); + gpu!(scope, offset_rhs = skip_col * rhs_stride_col); // Batch offset for the output. - gpu!(scope, offset_output = n_rows * n_cols); + let offset_output = scope.create_local(Elem::UInt); + let batch_dims = scope.create_local(Elem::UInt); + gpu!(scope, offset_output = m * n); gpu!(scope, offset_output = offset_output * batch); // Batch offset for the lhs & rhs matrices. + gpu!(scope, batch_dims = rank - 2u32); IndexOffsetGlobalWithLayout { tensors: vec![lhs, rhs], indexes: vec![offset_lhs, offset_rhs], @@ -114,46 +130,13 @@ impl MatmulTiling2dPaddedShader { } .expand(scope); - // Calculate the dot product (row X col). - let sum = scope.create_local(out.item()); - - // Initialize the sum to zero. - let zero: Variable = 0f32.into(); - gpu!(scope, sum = zero); - - // Loop over the k dimension. - gpu!( - scope, - range(0u32, k).for_each(|i, scope| { - let lhs_index = scope.create_local(Elem::UInt); - let rhs_index = scope.create_local(Elem::UInt); - - let lhs_value = scope.create_local(lhs.item()); - let rhs_value = scope.create_local(rhs.item()); - let out_value = scope.create_local(out.item()); - - gpu!(scope, lhs_index = row * k); - gpu!(scope, lhs_index = lhs_index + i); - gpu!(scope, lhs_index = lhs_index + offset_lhs); - - gpu!(scope, rhs_index = i * n_cols); - gpu!(scope, rhs_index = rhs_index + col); - gpu!(scope, rhs_index = rhs_index + offset_rhs); - - gpu!(scope, lhs_value = lhs[lhs_index]); - gpu!(scope, rhs_value = rhs[rhs_index]); - - gpu!(scope, out_value = lhs_value * rhs_value); - gpu!(scope, sum += out_value); - }) - ); + // Phase 2: Loop over k for loading and computing - let out_index = scope.create_local(Elem::UInt); + let results = scope.create_local_array(lhs.item().elem(), results_size); + let tmp = scope.create_local(lhs.item()); + gpu!(scope, tmp = results[offset_lhs]); - gpu!(scope, out_index = row * n_cols); - gpu!(scope, out_index += col); - gpu!(scope, out_index += offset_output); - gpu!(scope, out[out_index] = sum); + // Phase 3: Write to output } } @@ -168,7 +151,7 @@ impl DynamicKernelSource for MatmulTiling2dPaddedEagerKernel { MatmulTiling2dPaddedShader { variables: gpu::BinaryOperator { lhs, rhs, out }, - block_size: self.config.grid_x, // TODO + config: self.config.clone(), } .expand(&mut scope); diff --git a/crates/burn-wgpu/src/compiler/wgsl/base.rs b/crates/burn-wgpu/src/compiler/wgsl/base.rs index 45ad810d32..a268b4d166 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/base.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/base.rs @@ -18,6 +18,7 @@ pub enum Variable { scope_depth: u8, }, SharedMemory(u16, Item, u32), + LocalArray(u16, Item, u8, u32), Id, LocalInvocationIndex, LocalInvocationIdX, @@ -79,6 +80,7 @@ impl Variable { Variable::GlobalInputArray(_, _) => false, Variable::GlobalOutputArray(_, _) => false, Variable::SharedMemory(_, _, _) => false, + Variable::LocalArray(_, _, _, _) => false, 
             Variable::Local {
                 index: _,
                 item: _,
@@ -110,6 +112,7 @@ impl Variable {
             Self::GlobalInputArray(_, e) => *e,
             Self::GlobalOutputArray(_, e) => *e,
             Self::SharedMemory(_, e, _) => *e,
+            Self::LocalArray(_, e, _, _) => *e,
             Self::Local {
                 index: _,
                 item,
@@ -222,6 +225,9 @@ impl Display for Variable {
             Variable::SharedMemory(number, _, _) => {
                 f.write_fmt(format_args!("shared_memory_{number}"))
             }
+            Variable::LocalArray(number, _, scope_depth, _) => {
+                f.write_fmt(format_args!("a_{scope_depth}_{number}"))
+            }
             Variable::Id => f.write_str("id"),
             Variable::LocalInvocationIndex => f.write_str("local_idx"),
             Variable::LocalInvocationIdX => f.write_str("local_invocation_id.x"),
diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs
index f84be0c019..4532c8bac2 100644
--- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs
+++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs
@@ -1,3 +1,4 @@
+use super::LocalArray;
 use super::{shader::ComputeShader, Item, SharedMemory};
 use crate::compiler::wgsl;
 use crate::{FloatElement, IntElement};
@@ -19,6 +20,7 @@ pub struct WgslCompiler<F: FloatElement, I: IntElement> {
     shape: bool,
     num_workgroups: bool,
     shared_memories: Vec<SharedMemory>,
+    local_arrays: Vec<LocalArray>,
     _float: PhantomData<F>,
     _int: PhantomData<I>,
 }
@@ -44,6 +46,7 @@ impl<F: FloatElement, I: IntElement> Default for WgslCompiler<F, I> {
             shape: false,
             num_workgroups: false,
             shared_memories: Vec::default(),
+            local_arrays: Vec::default(),
             _float: PhantomData,
             _int: PhantomData,
         }
@@ -102,6 +105,7 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
                 .map(|(name, binding)| (name, Self::compile_binding(binding)))
                 .collect(),
             shared_memories: self.shared_memories.clone(),
+            local_arrays: self.local_arrays.clone(),
             workgroup_size: value.workgroup_size,
             global_invocation_id: self.global_invocation_id || self.id,
             local_invocation_index: self.local_invocation_index,
@@ -163,6 +167,14 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
                 }
                 wgsl::Variable::SharedMemory(index, item, size)
             }
+            gpu::Variable::LocalArray(index, item, scope_depth, size) => {
+                let item = Self::compile_item(item);
+                if !self.local_arrays.iter().any(|s| s.index == index) {
+                    self.local_arrays
+                        .push(LocalArray::new(index, item, scope_depth, size));
+                }
+                wgsl::Variable::LocalArray(index, item, scope_depth, size)
+            }
             gpu::Variable::Id => {
                 self.id = true;
                 wgsl::Variable::Id
diff --git a/crates/burn-wgpu/src/compiler/wgsl/shader.rs b/crates/burn-wgpu/src/compiler/wgsl/shader.rs
index 6036584b0d..79d890c4d4 100644
--- a/crates/burn-wgpu/src/compiler/wgsl/shader.rs
+++ b/crates/burn-wgpu/src/compiler/wgsl/shader.rs
@@ -5,7 +5,6 @@ use std::fmt::Display;
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
 pub enum Location {
     Storage,
-    #[allow(dead_code)]
     Workgroup,
 }
 
@@ -42,12 +41,32 @@ impl SharedMemory {
     }
 }
 
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct LocalArray {
+    pub index: u16,
+    item: Item,
+    name: u8,
+    size: u32,
+}
+
+impl LocalArray {
+    pub fn new(index: u16, item: Item, name: u8, size: u32) -> Self {
+        Self {
+            index,
+            item,
+            name,
+            size,
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct ComputeShader {
     pub inputs: Vec<Binding>,
     pub outputs: Vec<Binding>,
     pub named: Vec<(String, Binding)>,
     pub shared_memories: Vec<SharedMemory>,
+    pub local_arrays: Vec<LocalArray>,
     pub workgroup_size: WorkgroupSize,
     pub global_invocation_id: bool,
     pub local_invocation_index: bool,
@@ -72,10 +91,10 @@ impl Display for ComputeShader {
             )?;
         }
 
-        for shared_memory in self.shared_memories.iter() {
+        for array in self.shared_memories.iter() {
             f.write_fmt(format_args!(
-                "var<{}> shared_memory_{}: array<{}, {}>;\n\n",
-                shared_memory.location, shared_memory.index, shared_memory.item, shared_memory.size
+                "var<{}> array_{}: array<{}, {}>;\n\n",
+                array.location, array.index, array.item, array.size
             ))?;
         }
 
@@ -115,12 +134,22 @@ fn main(
             f.write_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n")?;
         }
 
-        f.write_fmt(format_args!(
-            ") {{
-{}
-}}",
-            self.body
-        ))?;
+        // Open body
+        f.write_fmt(format_args!(") {{"))?;
+
+        // Local arrays
+        for array in self.local_arrays.iter() {
+            f.write_fmt(format_args!(
+                "var a_{}_{}: array<{}, {}>;\n\n",
+                array.name, array.index, array.item, array.size
+            ))?;
+        }
+
+        // Body
+        f.write_fmt(format_args!("{}", self.body))?;
+
+        // Close body
+        f.write_fmt(format_args!("}}"))?;
 
         for extension in self.extensions.iter() {
             f.write_fmt(format_args!("{extension}\n\n"))?;

From fa222a0477b92fd1e3e15d25a40f7c9fec13da01 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Tue, 12 Mar 2024 16:12:15 -0400
Subject: [PATCH 05/25] advancing tiling2d

---
 .../src/kernel/matmul/tiling2d_padded.rs | 33 +++++++++++++------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs
index 9b9e937301..37e33c5c2f 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs
@@ -6,7 +6,7 @@ use crate::{
         Execution, InputInfo, OutputInfo, WorkgroupLaunch,
     },
     element::JitElement,
-    gpu::{gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable},
+    gpu::{gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Item, Scope, Variable},
     kernel::{into_contiguous, DynamicKernelSource, SourceTemplate},
     tensor::JitTensor,
     Runtime,
@@ -46,6 +46,7 @@ impl MatmulTiling2dPaddedShader {
         // Config variables
         let block_size_m: Variable = self.config.block_size_m.into();
+        let block_size_k: Variable = self.config.block_size_k.into();
         let block_size_n: Variable = self.config.block_size_n.into();
         let tile_size_m: Variable = self.config.tile_size_m.into();
         let tile_size_n: Variable = self.config.tile_size_n.into();
@@ -61,12 +62,12 @@ impl MatmulTiling2dPaddedShader {
         let rank = Variable::Rank;
         let penultimate_dim = scope.create_local(Elem::UInt);
         gpu!(scope, penultimate_dim = rank - 1u32);
-        let m = scope.create_local(Elem::UInt);
-        let k = scope.create_local(Elem::UInt);
-        let n = scope.create_local(Elem::UInt);
-        gpu!(scope, m = shape(lhs, penultimate_dim));
-        gpu!(scope, k = shape(rhs, penultimate_dim));
-        gpu!(scope, n = shape(rhs, rank));
+        let M = scope.create_local(Elem::UInt);
+        let K = scope.create_local(Elem::UInt);
+        let N = scope.create_local(Elem::UInt);
+        gpu!(scope, M = shape(lhs, penultimate_dim));
+        gpu!(scope, K = shape(rhs, penultimate_dim));
+        gpu!(scope, N = shape(rhs, rank));
@@ -115,7 +116,7 @@ impl MatmulTiling2dPaddedShader {
         // Batch offset for the output.
         let offset_output = scope.create_local(Elem::UInt);
         let batch_dims = scope.create_local(Elem::UInt);
-        gpu!(scope, offset_output = m * n);
+        gpu!(scope, offset_output = M * N);
         gpu!(scope, offset_output = offset_output * batch);
 
         // Batch offset for the lhs & rhs matrices.
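The next hunk advances the loop over k in steps of block_size_k and computes n_loops with an exact division. As its inline comment notes, this only holds because pad_round has already rounded the k dimension up to a multiple of block_size_k; an unpadded variant would need a ceiling division instead. A sketch of that variant in the same gpu! dialect (hypothetical, not part of this patch):

    // ceil(K / block_size_k) via the integer identity (a + b - 1) / b,
    // for a hypothetical kernel that skips the padding step.
    let n_loops = scope.create_local(Elem::UInt);
    gpu!(scope, n_loops = K + block_size_k);
    gpu!(scope, n_loops = n_loops - 1u32);
    gpu!(scope, n_loops = n_loops / block_size_k);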
@@ -133,8 +134,20 @@ impl MatmulTiling2dPaddedShader { // Phase 2: Loop over k for loading and computing let results = scope.create_local_array(lhs.item().elem(), results_size); - let tmp = scope.create_local(lhs.item()); - gpu!(scope, tmp = results[offset_lhs]); + let register_m = scope.create_local(Item::Vec4(lhs.item().elem())); + let register_n = scope.create_local(Item::Vec4(lhs.item().elem())); + + let n_loops = scope.create_local(Elem::UInt); + gpu!(scope, n_loops = K / block_size_k); // assumes padding, otherwise ceil + gpu!( + scope, + range(0, n_loops).for_each(|i, scope| { + let k = scope.create_local(Elem::UInt); + gpu!(scope, k = i * block_size_k); + + // HERE + }) + ) // Phase 3: Write to output } From 4419f453496aaa68f220c77663ece7225cdf5762 Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 13 Mar 2024 09:09:17 -0400 Subject: [PATCH 06/25] advancing tiling2d --- .../src/kernel/matmul/tiling2d_padded.rs | 60 ++++++++++++++++++- crates/burn-wgpu/src/compiler/wgsl/shader.rs | 2 +- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs index 37e33c5c2f..a6496a72e9 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs @@ -136,16 +136,72 @@ impl MatmulTiling2dPaddedShader { let results = scope.create_local_array(lhs.item().elem(), results_size); let register_m = scope.create_local(Item::Vec4(lhs.item().elem())); let register_n = scope.create_local(Item::Vec4(lhs.item().elem())); + let shared_lhs = scope.create_shared( + Item::Vec4(lhs.item().elem()), + self.config.block_size_m as u32 * self.config.block_size_k as u32 / 4u32, + ); + let shared_rhs = scope.create_shared( + Item::Vec4(rhs.item().elem()), + self.config.block_size_k as u32 * self.config.block_size_n as u32 / 4u32, + ); let n_loops = scope.create_local(Elem::UInt); gpu!(scope, n_loops = K / block_size_k); // assumes padding, otherwise ceil gpu!( scope, - range(0, n_loops).for_each(|i, scope| { + range(0u32, n_loops).for_each(|i, scope| { + // Equivalent of looping from 0 to K with steps block_size_k let k = scope.create_local(Elem::UInt); gpu!(scope, k = i * block_size_k); - // HERE + // Phase 2.1: Load to shared memory + + // LHS + for j in 0u32..4u32 { + let current_col = scope.create_local(Elem::UInt); + gpu!(scope, current_col = thread_col + j); + + let aligned_with_shared_memory = scope.create_local(Elem::Bool); + gpu!( + scope, + aligned_with_shared_memory = current_col < block_size_k + ); + + // TODO if current_col >= B_K then we could break and not try other j + gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ + let lhs_sm_position = scope.create_local(Elem::UInt); + gpu!(scope, lhs_sm_position = thread_row / 4u32); + gpu!(scope, lhs_sm_position *= block_size_k); + gpu!(scope, lhs_sm_position += current_col); + + let lhs_position_0 = scope.create_local(Elem::UInt); + gpu!(scope, lhs_position_0 = k + current_col); + gpu!(scope, lhs_position_0 *= lhs_stride_col); + let tmp = scope.create_local(Elem::UInt); + gpu!(scope, tmp = thread_row * lhs_stride_row); + gpu!(scope, lhs_position_0 += tmp); + gpu!(scope, lhs_position_0 += offset_lhs); + let lhs_position_1 = scope.create_local(Elem::UInt); + let lhs_position_2 = scope.create_local(Elem::UInt); + let lhs_position_3 = scope.create_local(Elem::UInt); + gpu!(scope, lhs_position_1 = lhs_position_0 + lhs_stride_row); + gpu!(scope, lhs_position_2 = lhs_position_1 + 
lhs_stride_row); + gpu!(scope, lhs_position_3 = lhs_position_2 + lhs_stride_row); + + let lhs_0 = scope.create_local(lhs.item().elem()); + let lhs_1 = scope.create_local(lhs.item().elem()); + let lhs_2 = scope.create_local(lhs.item().elem()); + let lhs_3 = scope.create_local(lhs.item().elem()); + gpu!(scope, lhs_0 = lhs[lhs_position_0]); + gpu!(scope, lhs_1 = lhs[lhs_position_1]); + gpu!(scope, lhs_2 = lhs[lhs_position_2]); + gpu!(scope, lhs_3 = lhs[lhs_position_3]); + + let lhs_vec4 = scope.create_local(shared_lhs.item()); + // gpu!(scope, lhs_vec4 = [lhs_0, lhs_1, lhs_2, lhs_3]); + gpu!(scope, shared_lhs[lhs_sm_position] = lhs_vec4); + })); + } }) ) diff --git a/crates/burn-wgpu/src/compiler/wgsl/shader.rs b/crates/burn-wgpu/src/compiler/wgsl/shader.rs index 79d890c4d4..8229ac4510 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/shader.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/shader.rs @@ -93,7 +93,7 @@ impl Display for ComputeShader { for array in self.shared_memories.iter() { f.write_fmt(format_args!( - "var<{}> array_{}: array<{}, {}>;\n\n", + "var<{}> shared_memory_{}: array<{}, {}>;\n\n", array.location, array.index, array.item, array.size ))?; } From 9b6f2d19ebdd6e7449804f4d41a719bbfeea9a8d Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 13 Mar 2024 10:57:38 -0400 Subject: [PATCH 07/25] advancing tiling2d --- .../src/codegen/dialect/gpu/macros.rs | 6 + .../src/codegen/dialect/gpu/operation.rs | 11 ++ .../burn-jit/src/codegen/dialect/gpu/scope.rs | 2 +- .../src/codegen/dialect/gpu/vectorization.rs | 18 +- crates/burn-jit/src/fusion/tracing/builder.rs | 7 + .../src/kernel/matmul/tiling2d_padded.rs | 173 +++++++++++++----- .../burn-wgpu/src/compiler/wgsl/compiler.rs | 7 + .../src/compiler/wgsl/instructions.rs | 10 + 8 files changed, 185 insertions(+), 49 deletions(-) diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index 8438d5753d..eba8088341 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -283,6 +283,12 @@ macro_rules! gpu { gpu!(unary $input, $out) )); }; + // out = vec4(a, b, c, d) + ($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => { + $scope.register($crate::codegen::dialect::gpu::Operator::AssignVec4( + $crate::codegen::dialect::gpu::AssignVec4Operator{a:$a,b:$b,c:$c,d:$d,out:$out} + )); + }; // out = input ($scope:expr, $out:ident = $input:ident) => { gpu!($scope, $out = cast($input)) diff --git a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs index ddb956c525..bad960a1b8 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs @@ -57,6 +57,7 @@ pub enum Operator { BitwiseXor(BinaryOperator), ShiftLeft(BinaryOperator), ShiftRight(BinaryOperator), + AssignVec4(AssignVec4Operator), } /// All metadata that can be access in a shader. 
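The AssignVec4 operator backs the vec4(...) arm added to the gpu! macro above. For illustration, the shared-memory loader later in this patch uses it to pack four scalar loads into a single vectorized store (lhs_1 through lhs_3 are read the same way as lhs_0; sketch adapted from the kernel code below):

    // Pack four scalars read from global memory into one vec4 value,
    // then write it to shared memory with a single store.
    let lhs_vec4 = scope.create_local(shared_memory.item());
    gpu!(scope, lhs_vec4 = vec4(lhs_0, lhs_1, lhs_2, lhs_3));
    gpu!(scope, shared_memory[lhs_sm_position] = lhs_vec4);

The WGSL backend prints the operator as out = vec4(a, b, c, d); through the matching Instruction::AssignVec4 arm.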
@@ -105,6 +106,16 @@ pub struct ClampOperator { pub out: Variable, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[allow(missing_docs)] +pub struct AssignVec4Operator { + pub a: Variable, + pub b: Variable, + pub c: Variable, + pub d: Variable, + pub out: Variable, +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[allow(missing_docs)] pub struct ReadGlobalOperator { diff --git a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs index 47868acfac..585c59a601 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs @@ -346,7 +346,7 @@ impl Scope { let item = item.into(); let index = self.new_shared_index(); let shared_memory = Variable::SharedMemory(index, item, shared_memory_size); - self.local_arrays.push(shared_memory); + self.shared_memories.push(shared_memory); shared_memory } diff --git a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs index 2c59a1d1ea..936358cf45 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs @@ -1,4 +1,7 @@ -use super::{BinaryOperator, ClampOperator, Item, Operation, Operator, UnaryOperator, Variable}; +use super::{ + AssignVec4Operator, BinaryOperator, ClampOperator, Item, Operation, Operator, UnaryOperator, + Variable, +}; /// Define a vectorization scheme. #[allow(dead_code)] @@ -78,6 +81,7 @@ impl Operator { Operator::BitwiseXor(op) => Operator::BitwiseXor(op.vectorize(vectorization)), Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)), Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)), + Operator::AssignVec4(op) => Operator::AssignVec4(op.vectorize(vectorization)), } } } @@ -112,6 +116,18 @@ impl ClampOperator { } } +impl AssignVec4Operator { + pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { + Self { + a: self.a, + b: self.b, + c: self.c, + d: self.d, + out: self.out.vectorize(vectorization), + } + } +} + impl Variable { pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Self { match self { diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index e56ec520fb..0172b83ffa 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -346,6 +346,13 @@ impl TraceBuilder { &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), + gpu::Operator::AssignVec4(op) => { + mark(&op.a, &mut local_tensor_ids_input); + mark(&op.b, &mut local_tensor_ids_input); + mark(&op.c, &mut local_tensor_ids_input); + mark(&op.d, &mut local_tensor_ids_input); + mark(&op.out, &mut local_tensor_ids_output); + } }, Operation::Procedure(proc) => { match proc { diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs index a6496a72e9..fe6ce1dac7 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs @@ -6,7 +6,10 @@ use crate::{ Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, element::JitElement, - gpu::{gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Item, Scope, Variable}, + gpu::{ + gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Item, Scope, + Synchronization, Variable, + }, kernel::{into_contiguous, DynamicKernelSource, SourceTemplate}, 
tensor::JitTensor, Runtime, @@ -157,56 +160,132 @@ impl MatmulTiling2dPaddedShader { // Phase 2.1: Load to shared memory // LHS - for j in 0u32..4u32 { - let current_col = scope.create_local(Elem::UInt); - gpu!(scope, current_col = thread_col + j); - - let aligned_with_shared_memory = scope.create_local(Elem::Bool); - gpu!( - scope, - aligned_with_shared_memory = current_col < block_size_k - ); - - // TODO if current_col >= B_K then we could break and not try other j - gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ - let lhs_sm_position = scope.create_local(Elem::UInt); - gpu!(scope, lhs_sm_position = thread_row / 4u32); - gpu!(scope, lhs_sm_position *= block_size_k); - gpu!(scope, lhs_sm_position += current_col); - - let lhs_position_0 = scope.create_local(Elem::UInt); - gpu!(scope, lhs_position_0 = k + current_col); - gpu!(scope, lhs_position_0 *= lhs_stride_col); - let tmp = scope.create_local(Elem::UInt); - gpu!(scope, tmp = thread_row * lhs_stride_row); - gpu!(scope, lhs_position_0 += tmp); - gpu!(scope, lhs_position_0 += offset_lhs); - let lhs_position_1 = scope.create_local(Elem::UInt); - let lhs_position_2 = scope.create_local(Elem::UInt); - let lhs_position_3 = scope.create_local(Elem::UInt); - gpu!(scope, lhs_position_1 = lhs_position_0 + lhs_stride_row); - gpu!(scope, lhs_position_2 = lhs_position_1 + lhs_stride_row); - gpu!(scope, lhs_position_3 = lhs_position_2 + lhs_stride_row); - - let lhs_0 = scope.create_local(lhs.item().elem()); - let lhs_1 = scope.create_local(lhs.item().elem()); - let lhs_2 = scope.create_local(lhs.item().elem()); - let lhs_3 = scope.create_local(lhs.item().elem()); - gpu!(scope, lhs_0 = lhs[lhs_position_0]); - gpu!(scope, lhs_1 = lhs[lhs_position_1]); - gpu!(scope, lhs_2 = lhs[lhs_position_2]); - gpu!(scope, lhs_3 = lhs[lhs_position_3]); - - let lhs_vec4 = scope.create_local(shared_lhs.item()); - // gpu!(scope, lhs_vec4 = [lhs_0, lhs_1, lhs_2, lhs_3]); - gpu!(scope, shared_lhs[lhs_sm_position] = lhs_vec4); - })); - } + load_shared_memory( + scope, + k, + block_size_k, + block_size_n, + thread_col, + thread_row, + lhs_stride_col, + lhs_stride_row, + lhs, + offset_lhs, + shared_lhs, + true, + ); + + // RHS + load_shared_memory( + scope, + k, + block_size_k, + block_size_n, + thread_row, + thread_col, + rhs_stride_row, + rhs_stride_col, + rhs, + offset_rhs, + shared_rhs, + false, + ); + + scope.register(Synchronization::WorkgroupBarrier); + + // Phase 2.2: Compute intermediate results + + computation_loop(); + + scope.register(Synchronization::WorkgroupBarrier); }) - ) + ); // Phase 3: Write to output } + + fn load_shared_memory( + scope: &mut Scope, + k: Variable, + block_size_k: Variable, + block_size_n: Variable, + thread_idx_1: Variable, + thread_idx_2: Variable, + stride_1: Variable, + stride_2: Variable, + input: Variable, + input_offset: Variable, + shared_memory: Variable, + is_lhs: bool, + ) { + for j in 0u32..4u32 { + let current_col = scope.create_local(Elem::UInt); + gpu!(scope, current_col = thread_idx_1 + j); + + let aligned_with_shared_memory = scope.create_local(Elem::Bool); + gpu!( + scope, + aligned_with_shared_memory = current_col < block_size_k + ); + + gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ + let lhs_sm_position = scope.create_local(Elem::UInt); + if is_lhs { + gpu!(scope, lhs_sm_position = thread_idx_2 / 4u32); + gpu!(scope, lhs_sm_position *= block_size_k); + gpu!(scope, lhs_sm_position += current_col); + } else { + gpu!(scope, lhs_sm_position = current_col * block_size_n); + gpu!(scope, lhs_sm_position 
+= thread_idx_2); + gpu!(scope, lhs_sm_position = lhs_sm_position / 4u32); + } + + let lhs_position_0 = scope.create_local(Elem::UInt); + gpu!(scope, lhs_position_0 = k + current_col); + gpu!(scope, lhs_position_0 *= stride_1); + let tmp = scope.create_local(Elem::UInt); + gpu!(scope, tmp = thread_idx_2 * stride_2); + gpu!(scope, lhs_position_0 += tmp); + gpu!(scope, lhs_position_0 += input_offset); + let lhs_position_1 = scope.create_local(Elem::UInt); + let lhs_position_2 = scope.create_local(Elem::UInt); + let lhs_position_3 = scope.create_local(Elem::UInt); + gpu!(scope, lhs_position_1 = lhs_position_0 + stride_2); + gpu!(scope, lhs_position_2 = lhs_position_1 + stride_2); + gpu!(scope, lhs_position_3 = lhs_position_2 + stride_2); + + let lhs_0 = scope.create_local(input.item().elem()); + let lhs_1 = scope.create_local(input.item().elem()); + let lhs_2 = scope.create_local(input.item().elem()); + let lhs_3 = scope.create_local(input.item().elem()); + gpu!(scope, lhs_0 = input[lhs_position_0]); + gpu!(scope, lhs_1 = input[lhs_position_1]); + gpu!(scope, lhs_2 = input[lhs_position_2]); + gpu!(scope, lhs_3 = input[lhs_position_3]); + + let lhs_vec4 = scope.create_local(shared_memory.item()); + gpu!(scope, lhs_vec4 = vec4(lhs_0, lhs_1, lhs_2, lhs_3)); + gpu!(scope, shared_memory[lhs_sm_position] = lhs_vec4); + + }).else(|scope|{ + scope.register(Branch::Break); // TODO test if faster, else remove + })); + } + } + + fn computation_loop( + scope: &mut Scope, + block_size_k: Variable, // needed in computation, but also use attribute for unrolling + thread_row: Variable, + shared_lhs: Variable, + register_m: Variable, + block_size_n: Variable, + thread_col: Variable, + shared_rhs: Variable, + register_n: Variable, // TM/TN use attribute for unrolling + results: Variable, + ) { + } } impl DynamicKernelSource for MatmulTiling2dPaddedEagerKernel { diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index 4532c8bac2..cb1d78e85f 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -543,6 +543,13 @@ impl WgslCompiler { rhs: self.compile_variable(op.rhs), out: self.compile_variable(op.out), }, + gpu::Operator::AssignVec4(op) => wgsl::Instruction::AssignVec4 { + a: self.compile_variable(op.a), + b: self.compile_variable(op.b), + c: self.compile_variable(op.c), + d: self.compile_variable(op.d), + out: self.compile_variable(op.out), + }, } } diff --git a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs index 7b410b6c06..7235def106 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs @@ -205,6 +205,13 @@ pub enum Instruction { rhs: Variable, out: Variable, }, + AssignVec4 { + a: Variable, + b: Variable, + c: Variable, + d: Variable, + out: Variable, + }, } impl Display for Instruction { @@ -455,6 +462,9 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{ Instruction::ShiftRight { lhs, rhs, out } => { f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n")) } + Instruction::AssignVec4 { a, b, c, d, out } => { + f.write_fmt(format_args!("{out} = vec4({a}, {b}, {c}, {d});\n")) + } } } } From feddabcdfa630a604c7b19ff42a5edbfb95be75a Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 13 Mar 2024 12:27:25 -0400 Subject: [PATCH 08/25] tiling2d finished but buggy --- .../src/kernel/matmul/tiling2d_padded.rs | 123 ++++++++++++++---- 1 file changed, 100 insertions(+), 23 
deletions(-) diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs index fe6ce1dac7..02e0bc6639 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs @@ -136,15 +136,16 @@ impl MatmulTiling2dPaddedShader { // Phase 2: Loop over k for loading and computing - let results = scope.create_local_array(lhs.item().elem(), results_size); - let register_m = scope.create_local(Item::Vec4(lhs.item().elem())); - let register_n = scope.create_local(Item::Vec4(lhs.item().elem())); + let elem = lhs.item().elem(); + let results = scope.create_local_array(elem, results_size); + let register_m = scope.create_local(Item::Vec4(elem)); + let register_n = scope.create_local(Item::Vec4(elem)); let shared_lhs = scope.create_shared( - Item::Vec4(lhs.item().elem()), + Item::Vec4(elem), self.config.block_size_m as u32 * self.config.block_size_k as u32 / 4u32, ); let shared_rhs = scope.create_shared( - Item::Vec4(rhs.item().elem()), + Item::Vec4(elem), self.config.block_size_k as u32 * self.config.block_size_n as u32 / 4u32, ); @@ -160,11 +161,9 @@ impl MatmulTiling2dPaddedShader { // Phase 2.1: Load to shared memory // LHS - load_shared_memory( + self.load_shared_memory( scope, k, - block_size_k, - block_size_n, thread_col, thread_row, lhs_stride_col, @@ -176,11 +175,9 @@ impl MatmulTiling2dPaddedShader { ); // RHS - load_shared_memory( + self.load_shared_memory( scope, k, - block_size_k, - block_size_n, thread_row, thread_col, rhs_stride_row, @@ -195,20 +192,50 @@ impl MatmulTiling2dPaddedShader { // Phase 2.2: Compute intermediate results - computation_loop(); + self.computation_loop( + scope, thread_col, thread_row, shared_lhs, shared_rhs, register_m, register_n, + results, + ); scope.register(Synchronization::WorkgroupBarrier); }) ); // Phase 3: Write to output + for res_idx_m in 0..self.config.tile_size_m { + for res_idx_n in 0..self.config.tile_size_n { + let results_position = scope.create_local(Elem::UInt); + gpu!( + scope, + results_position = res_idx_m * self.config.tile_size_n + ); + gpu!(scope, results_position += res_idx_n); + + let result = scope.create_local(elem); + gpu!(scope, result = results[results_position]); + + let output_position = scope.create_local(Elem::UInt); + let output_position_tmp1 = scope.create_local(Elem::UInt); + let output_position_tmp2 = scope.create_local(Elem::UInt); + gpu!(scope, output_position_tmp1 = row + res_idx_m); + gpu!(scope, output_position_tmp1 *= out_stride_row); + gpu!(scope, output_position_tmp2 = col + res_idx_n); + gpu!(scope, output_position_tmp2 *= out_stride_col); + gpu!( + scope, + output_position = output_position_tmp1 + output_position_tmp2 + ); + gpu!(scope, output_position += offset_output); + + gpu!(scope, out[output_position] = result); + } + } } fn load_shared_memory( + &self, scope: &mut Scope, k: Variable, - block_size_k: Variable, - block_size_n: Variable, thread_idx_1: Variable, thread_idx_2: Variable, stride_1: Variable, @@ -218,6 +245,10 @@ impl MatmulTiling2dPaddedShader { shared_memory: Variable, is_lhs: bool, ) { + let block_size_k: Variable = self.config.block_size_k.into(); + let block_size_n: Variable = self.config.block_size_n.into(); + let elem = input.item().elem(); + for j in 0u32..4u32 { let current_col = scope.create_local(Elem::UInt); gpu!(scope, current_col = thread_idx_1 + j); @@ -254,10 +285,10 @@ impl MatmulTiling2dPaddedShader { gpu!(scope, lhs_position_2 = lhs_position_1 + stride_2); 
gpu!(scope, lhs_position_3 = lhs_position_2 + stride_2); - let lhs_0 = scope.create_local(input.item().elem()); - let lhs_1 = scope.create_local(input.item().elem()); - let lhs_2 = scope.create_local(input.item().elem()); - let lhs_3 = scope.create_local(input.item().elem()); + let lhs_0 = scope.create_local(elem); + let lhs_1 = scope.create_local(elem); + let lhs_2 = scope.create_local(elem); + let lhs_3 = scope.create_local(elem); gpu!(scope, lhs_0 = input[lhs_position_0]); gpu!(scope, lhs_1 = input[lhs_position_1]); gpu!(scope, lhs_2 = input[lhs_position_2]); @@ -274,20 +305,66 @@ impl MatmulTiling2dPaddedShader { } fn computation_loop( + &self, scope: &mut Scope, - block_size_k: Variable, // needed in computation, but also use attribute for unrolling + thread_col: Variable, thread_row: Variable, shared_lhs: Variable, - register_m: Variable, - block_size_n: Variable, - thread_col: Variable, shared_rhs: Variable, - register_n: Variable, // TM/TN use attribute for unrolling + register_m: Variable, + register_n: Variable, results: Variable, ) { + let block_size_k: Variable = self.config.block_size_k.into(); + let block_size_n: Variable = self.config.block_size_n.into(); + let elem = results.item().elem(); + + for dot_index in 0..self.config.block_size_k { + // Load a subcolumn of values from lhs + let lhs_sm_position = scope.create_local(Elem::UInt); + gpu!(scope, lhs_sm_position = thread_row / 4u32); + gpu!(scope, lhs_sm_position *= block_size_k); + gpu!(scope, lhs_sm_position += dot_index); + gpu!(scope, register_m = shared_lhs[lhs_sm_position]); + + // Load a subrow of values from rhs + let rhs_sm_position = scope.create_local(Elem::UInt); + gpu!(scope, rhs_sm_position = dot_index * block_size_n); + gpu!(scope, rhs_sm_position += thread_col); + gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32); + gpu!(scope, register_n = shared_rhs[rhs_sm_position]); + + for res_idx_m in 0..self.config.tile_size_m { + for res_idx_n in 0..self.config.tile_size_n { + let registered_m = scope.create_local(elem); + let registered_n = scope.create_local(elem); + gpu!(scope, registered_m = register_m[res_idx_m]); + gpu!(scope, registered_n = register_n[res_idx_n]); + + let multiplied = scope.create_local(elem); + gpu!(scope, multiplied = registered_m * registered_n); + + let results_position = scope.create_local(Elem::UInt); + gpu!( + scope, + results_position = res_idx_m * self.config.tile_size_n + ); + gpu!(scope, results_position += res_idx_n); + + let results_before = scope.create_local(elem); + gpu!(scope, results_before = results[results_position]); + let results_after = scope.create_local(elem); + gpu!(scope, results_after = results_before + multiplied); + + gpu!(scope, results[results_position] = results_after); + } + } + } } } +// TODO try to reuse variables declared before the scope + impl DynamicKernelSource for MatmulTiling2dPaddedEagerKernel { fn source(&self) -> SourceTemplate { let mut scope = gpu::Scope::root(); From 7e07825acc670674f3c0b30e7ecde04b4f6162cc Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 13 Mar 2024 14:44:45 -0400 Subject: [PATCH 09/25] configurable unrolling --- .../src/codegen/dialect/gpu/branch.rs | 12 + .../src/codegen/dialect/gpu/macros.rs | 10 +- .../src/kernel/matmul/tiling2d_padded.rs | 307 ++++++++++-------- 3 files changed, 200 insertions(+), 129 deletions(-) diff --git a/crates/burn-jit/src/codegen/dialect/gpu/branch.rs b/crates/burn-jit/src/codegen/dialect/gpu/branch.rs index 1a617a8dd1..fdc3ea9ba4 100644 --- 
a/crates/burn-jit/src/codegen/dialect/gpu/branch.rs
+++ b/crates/burn-jit/src/codegen/dialect/gpu/branch.rs
@@ -119,3 +119,15 @@ impl Loop {
         parent_scope.register(Branch::Loop(op));
     }
 }
+
+#[allow(missing_docs)]
+pub struct UnrolledRangeLoop;
+
+impl UnrolledRangeLoop {
+    /// Registers an unrolled range loop to the given scope.
+    pub fn register<F: Fn(Variable, &mut Scope)>(scope: &mut Scope, start: u32, end: u32, func: F) {
+        for i in start..end {
+            func(i.into(), scope);
+        }
+    }
+}
diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs
index eba8088341..7a2738a4f3 100644
--- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs
+++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs
@@ -322,10 +322,18 @@ macro_rules! gpu {
             out: $out.into(),
         });
     };
-    // range(start, end).for_each(|scope| { ... })
+    // range(start, end).for_each(|i, scope| { ... })
     ($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
         $crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
     };
+    // range(start, end, unroll).for_each(|i, scope| { ... })
+    ($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => {
+        if $unroll {
+            $crate::codegen::dialect::gpu::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), $arg);
+        } else {
+            $crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
+        }
+    };
     // loop(|scope| { ... })
     ($scope:expr, loop($arg:expr)) => {
         $crate::codegen::dialect::gpu::Loop::register($scope, $arg);
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs
index 02e0bc6639..39e3ee00ae 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs
@@ -36,6 +36,7 @@ struct MatmulTiling2dPaddedEagerKernel<R: Runtime> {
 struct MatmulTiling2dPaddedShader {
     variables: BinaryOperator,
     config: Tiling2dConfig,
+    unroll: bool,
 }
 
 impl MatmulTiling2dPaddedShader {
@@ -65,12 +66,12 @@ impl MatmulTiling2dPaddedShader {
         let rank = Variable::Rank;
         let penultimate_dim = scope.create_local(Elem::UInt);
         gpu!(scope, penultimate_dim = rank - 1u32);
-        let M = scope.create_local(Elem::UInt);
-        let K = scope.create_local(Elem::UInt);
-        let N = scope.create_local(Elem::UInt);
-        gpu!(scope, M = shape(lhs, penultimate_dim));
-        gpu!(scope, K = shape(rhs, penultimate_dim));
-        gpu!(scope, N = shape(rhs, rank));
+        let dim_m = scope.create_local(Elem::UInt);
+        let dim_k = scope.create_local(Elem::UInt);
+        let dim_n = scope.create_local(Elem::UInt);
+        gpu!(scope, dim_m = shape(lhs, penultimate_dim));
+        gpu!(scope, dim_k = shape(rhs, penultimate_dim));
+        gpu!(scope, dim_n = shape(rhs, rank));
 
         // Strides
         let lhs_stride_row = scope.create_local(Elem::UInt);
@@ -119,7 +120,7 @@ impl MatmulTiling2dPaddedShader {
         // Batch offset for the output.
         let offset_output = scope.create_local(Elem::UInt);
         let batch_dims = scope.create_local(Elem::UInt);
-        gpu!(scope, offset_output = M * N);
+        gpu!(scope, offset_output = dim_m * dim_n);
         gpu!(scope, offset_output = offset_output * batch);
 
         // Batch offset for the lhs & rhs matrices.
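The three-argument range arm above makes unrolling a codegen-time decision: when the flag is true, UnrolledRangeLoop expands the closure body once per index into straight-line operations; otherwise the usual RangeLoop is registered and the loop executes at runtime in the shader. A minimal sketch of the pattern the hunks below migrate to:

    // Same loop body either way; only the expansion strategy differs.
    gpu!(
        scope,
        range(0u32, 4u32, self.unroll).for_each(|j, scope| {
            let current_col = scope.create_local(Elem::UInt);
            gpu!(scope, current_col = thread_idx_1 + j);
        })
    );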
@@ -134,8 +135,6 @@ impl MatmulTiling2dPaddedShader { } .expand(scope); - // Phase 2: Loop over k for loading and computing - let elem = lhs.item().elem(); let results = scope.create_local_array(elem, results_size); let register_m = scope.create_local(Item::Vec4(elem)); @@ -150,7 +149,7 @@ impl MatmulTiling2dPaddedShader { ); let n_loops = scope.create_local(Elem::UInt); - gpu!(scope, n_loops = K / block_size_k); // assumes padding, otherwise ceil + gpu!(scope, n_loops = dim_k / block_size_k); // assumes padding, otherwise ceil gpu!( scope, range(0u32, n_loops).for_each(|i, scope| { @@ -158,8 +157,6 @@ impl MatmulTiling2dPaddedShader { let k = scope.create_local(Elem::UInt); gpu!(scope, k = i * block_size_k); - // Phase 2.1: Load to shared memory - // LHS self.load_shared_memory( scope, @@ -190,8 +187,6 @@ impl MatmulTiling2dPaddedShader { scope.register(Synchronization::WorkgroupBarrier); - // Phase 2.2: Compute intermediate results - self.computation_loop( scope, thread_col, thread_row, shared_lhs, shared_rhs, register_m, register_n, results, @@ -202,34 +197,16 @@ impl MatmulTiling2dPaddedShader { ); // Phase 3: Write to output - for res_idx_m in 0..self.config.tile_size_m { - for res_idx_n in 0..self.config.tile_size_n { - let results_position = scope.create_local(Elem::UInt); - gpu!( - scope, - results_position = res_idx_m * self.config.tile_size_n - ); - gpu!(scope, results_position += res_idx_n); - - let result = scope.create_local(elem); - gpu!(scope, result = results[results_position]); - - let output_position = scope.create_local(Elem::UInt); - let output_position_tmp1 = scope.create_local(Elem::UInt); - let output_position_tmp2 = scope.create_local(Elem::UInt); - gpu!(scope, output_position_tmp1 = row + res_idx_m); - gpu!(scope, output_position_tmp1 *= out_stride_row); - gpu!(scope, output_position_tmp2 = col + res_idx_n); - gpu!(scope, output_position_tmp2 *= out_stride_col); - gpu!( - scope, - output_position = output_position_tmp1 + output_position_tmp2 - ); - gpu!(scope, output_position += offset_output); - - gpu!(scope, out[output_position] = result); - } - } + self.write_to_output( + scope, + row, + col, + out_stride_row, + out_stride_col, + results, + offset_output, + out, + ); } fn load_shared_memory( @@ -249,59 +226,61 @@ impl MatmulTiling2dPaddedShader { let block_size_n: Variable = self.config.block_size_n.into(); let elem = input.item().elem(); - for j in 0u32..4u32 { - let current_col = scope.create_local(Elem::UInt); - gpu!(scope, current_col = thread_idx_1 + j); - - let aligned_with_shared_memory = scope.create_local(Elem::Bool); - gpu!( - scope, - aligned_with_shared_memory = current_col < block_size_k - ); + gpu!( + scope, + range(0_u32, 4u32, self.unroll).for_each(|j, scope| { + let current_col = scope.create_local(Elem::UInt); + gpu!(scope, current_col = thread_idx_1 + j); - gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ - let lhs_sm_position = scope.create_local(Elem::UInt); - if is_lhs { - gpu!(scope, lhs_sm_position = thread_idx_2 / 4u32); - gpu!(scope, lhs_sm_position *= block_size_k); - gpu!(scope, lhs_sm_position += current_col); - } else { - gpu!(scope, lhs_sm_position = current_col * block_size_n); - gpu!(scope, lhs_sm_position += thread_idx_2); - gpu!(scope, lhs_sm_position = lhs_sm_position / 4u32); - } + let aligned_with_shared_memory = scope.create_local(Elem::Bool); + gpu!( + scope, + aligned_with_shared_memory = current_col < block_size_k + ); - let lhs_position_0 = scope.create_local(Elem::UInt); - gpu!(scope, lhs_position_0 = k + 
current_col); - gpu!(scope, lhs_position_0 *= stride_1); - let tmp = scope.create_local(Elem::UInt); - gpu!(scope, tmp = thread_idx_2 * stride_2); - gpu!(scope, lhs_position_0 += tmp); - gpu!(scope, lhs_position_0 += input_offset); - let lhs_position_1 = scope.create_local(Elem::UInt); - let lhs_position_2 = scope.create_local(Elem::UInt); - let lhs_position_3 = scope.create_local(Elem::UInt); - gpu!(scope, lhs_position_1 = lhs_position_0 + stride_2); - gpu!(scope, lhs_position_2 = lhs_position_1 + stride_2); - gpu!(scope, lhs_position_3 = lhs_position_2 + stride_2); - - let lhs_0 = scope.create_local(elem); - let lhs_1 = scope.create_local(elem); - let lhs_2 = scope.create_local(elem); - let lhs_3 = scope.create_local(elem); - gpu!(scope, lhs_0 = input[lhs_position_0]); - gpu!(scope, lhs_1 = input[lhs_position_1]); - gpu!(scope, lhs_2 = input[lhs_position_2]); - gpu!(scope, lhs_3 = input[lhs_position_3]); - - let lhs_vec4 = scope.create_local(shared_memory.item()); - gpu!(scope, lhs_vec4 = vec4(lhs_0, lhs_1, lhs_2, lhs_3)); - gpu!(scope, shared_memory[lhs_sm_position] = lhs_vec4); - - }).else(|scope|{ - scope.register(Branch::Break); // TODO test if faster, else remove - })); - } + gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ + let lhs_sm_position = scope.create_local(Elem::UInt); + if is_lhs { + gpu!(scope, lhs_sm_position = thread_idx_2 / 4u32); + gpu!(scope, lhs_sm_position *= block_size_k); + gpu!(scope, lhs_sm_position += current_col); + } else { + gpu!(scope, lhs_sm_position = current_col * block_size_n); + gpu!(scope, lhs_sm_position += thread_idx_2); + gpu!(scope, lhs_sm_position = lhs_sm_position / 4u32); + } + + let lhs_position_0 = scope.create_local(Elem::UInt); + gpu!(scope, lhs_position_0 = k + current_col); + gpu!(scope, lhs_position_0 *= stride_1); + let tmp = scope.create_local(Elem::UInt); + gpu!(scope, tmp = thread_idx_2 * stride_2); + gpu!(scope, lhs_position_0 += tmp); + gpu!(scope, lhs_position_0 += input_offset); + let lhs_position_1 = scope.create_local(Elem::UInt); + let lhs_position_2 = scope.create_local(Elem::UInt); + let lhs_position_3 = scope.create_local(Elem::UInt); + gpu!(scope, lhs_position_1 = lhs_position_0 + stride_2); + gpu!(scope, lhs_position_2 = lhs_position_1 + stride_2); + gpu!(scope, lhs_position_3 = lhs_position_2 + stride_2); + + let lhs_0 = scope.create_local(elem); + let lhs_1 = scope.create_local(elem); + let lhs_2 = scope.create_local(elem); + let lhs_3 = scope.create_local(elem); + gpu!(scope, lhs_0 = input[lhs_position_0]); + gpu!(scope, lhs_1 = input[lhs_position_1]); + gpu!(scope, lhs_2 = input[lhs_position_2]); + gpu!(scope, lhs_3 = input[lhs_position_3]); + + let lhs_vec4 = scope.create_local(shared_memory.item()); + gpu!(scope, lhs_vec4 = vec4(lhs_0, lhs_1, lhs_2, lhs_3)); + gpu!(scope, shared_memory[lhs_sm_position] = lhs_vec4); + }).else(|scope|{ + scope.register(Branch::Break); // TODO test if faster, else remove + })); + }) + ); } fn computation_loop( @@ -319,47 +298,118 @@ impl MatmulTiling2dPaddedShader { let block_size_n: Variable = self.config.block_size_n.into(); let elem = results.item().elem(); - for dot_index in 0..self.config.block_size_k { - // Load a subcolumn of values from lhs - let lhs_sm_position = scope.create_local(Elem::UInt); - gpu!(scope, lhs_sm_position = thread_row / 4u32); - gpu!(scope, lhs_sm_position *= block_size_k); - gpu!(scope, lhs_sm_position += dot_index); - gpu!(scope, register_m = shared_lhs[lhs_sm_position]); - - // Load a subrow of values from rhs - let rhs_sm_position = 
scope.create_local(Elem::UInt); - gpu!(scope, rhs_sm_position = dot_index * block_size_n); - gpu!(scope, rhs_sm_position += thread_col); - gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32); - gpu!(scope, register_n = shared_rhs[rhs_sm_position]); - - for res_idx_m in 0..self.config.tile_size_m { - for res_idx_n in 0..self.config.tile_size_n { - let registered_m = scope.create_local(elem); - let registered_n = scope.create_local(elem); - gpu!(scope, registered_m = register_m[res_idx_m]); - gpu!(scope, registered_n = register_n[res_idx_n]); - - let multiplied = scope.create_local(elem); - gpu!(scope, multiplied = registered_m * registered_n); - - let results_position = scope.create_local(Elem::UInt); + gpu!( + scope, + range(0u32, self.config.block_size_k as u32, self.unroll).for_each( + |dot_index, scope| { + // Load a subcolumn of values from lhs + let lhs_sm_position = scope.create_local(Elem::UInt); + gpu!(scope, lhs_sm_position = thread_row / 4u32); + gpu!(scope, lhs_sm_position *= block_size_k); + gpu!(scope, lhs_sm_position += dot_index); + gpu!(scope, register_m = shared_lhs[lhs_sm_position]); + + // Load a subrow of values from rhs + let rhs_sm_position = scope.create_local(Elem::UInt); + gpu!(scope, rhs_sm_position = dot_index * block_size_n); + gpu!(scope, rhs_sm_position += thread_col); + gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32); + gpu!(scope, register_n = shared_rhs[rhs_sm_position]); + gpu!( scope, - results_position = res_idx_m * self.config.tile_size_n + range(0u32, self.config.tile_size_m as u32, self.unroll).for_each( + |res_idx_m, scope| { + gpu!( + scope, + range(0u32, self.config.tile_size_n as u32, self.unroll) + .for_each(|res_idx_n, scope| { + let registered_m = scope.create_local(elem); + let registered_n = scope.create_local(elem); + gpu!(scope, registered_m = register_m[res_idx_m]); + gpu!(scope, registered_n = register_n[res_idx_n]); + + let multiplied = scope.create_local(elem); + gpu!(scope, multiplied = registered_m * registered_n); + + let results_position = scope.create_local(Elem::UInt); + gpu!( + scope, + results_position = + res_idx_m * self.config.tile_size_n + ); + gpu!(scope, results_position += res_idx_n); + + let results_before = scope.create_local(elem); + gpu!(scope, results_before = results[results_position]); + let results_after = scope.create_local(elem); + gpu!( + scope, + results_after = results_before + multiplied + ); + + gpu!(scope, results[results_position] = results_after); + }) + ); + } + ) ); - gpu!(scope, results_position += res_idx_n); + } + ) + ); + } - let results_before = scope.create_local(elem); - gpu!(scope, results_before = results[results_position]); - let results_after = scope.create_local(elem); - gpu!(scope, results_after = results_before + multiplied); + fn write_to_output( + &self, + scope: &mut Scope, + row: Variable, + col: Variable, + out_stride_row: Variable, + out_stride_col: Variable, + results: Variable, + offset_output: Variable, + out: Variable, + ) { + let elem = results.item().elem(); - gpu!(scope, results[results_position] = results_after); + gpu!( + scope, + range(0u32, self.config.tile_size_m as u32, self.unroll).for_each( + |res_idx_m, scope| { + gpu!( + scope, + range(0u32, self.config.tile_size_n as u32, self.unroll).for_each( + |res_idx_n, scope| { + let results_position = scope.create_local(Elem::UInt); + gpu!( + scope, + results_position = res_idx_m * self.config.tile_size_n + ); + gpu!(scope, results_position += res_idx_n); + + let result = scope.create_local(elem); + gpu!(scope, 
result = results[results_position]); + + let output_position = scope.create_local(Elem::UInt); + let output_position_tmp1 = scope.create_local(Elem::UInt); + let output_position_tmp2 = scope.create_local(Elem::UInt); + gpu!(scope, output_position_tmp1 = row + res_idx_m); + gpu!(scope, output_position_tmp1 *= out_stride_row); + gpu!(scope, output_position_tmp2 = col + res_idx_n); + gpu!(scope, output_position_tmp2 *= out_stride_col); + gpu!( + scope, + output_position = output_position_tmp1 + output_position_tmp2 + ); + gpu!(scope, output_position += offset_output); + + gpu!(scope, out[output_position] = result); + } + ) + ); } - } - } + ) + ); } } @@ -377,6 +427,7 @@ impl DynamicKernelSource for MatmulTiling2dPaddedEagerKernel { MatmulTiling2dPaddedShader { variables: gpu::BinaryOperator { lhs, rhs, out }, config: self.config.clone(), + unroll: false, } .expand(&mut scope); From 855a77ca58cb81792927aef52a2e56f9c758ce15 Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 13 Mar 2024 17:59:29 -0400 Subject: [PATCH 10/25] not bugged --- crates/burn-jit/src/kernel/matmul/base.rs | 20 +- crates/burn-jit/src/kernel/matmul/simple.rs | 4 +- crates/burn-jit/src/kernel/matmul/tiling2d.rs | 10 +- .../src/kernel/matmul/tiling2d_padded.rs | 81 +++--- .../blocktiling_2d/generated_matmul copy.wgsl | 250 ++++++++++++++++++ .../blocktiling_2d/generated_matmul.wgsl | 211 +++++++++++++++ crates/burn-jit/src/tests/matmul.rs | 44 +++ 7 files changed, 579 insertions(+), 41 deletions(-) create mode 100644 crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul copy.wgsl create mode 100644 crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul.wgsl diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 2df48eceae..a4a18c8d5b 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -139,7 +139,7 @@ pub fn matmul( } } -pub(crate) fn launch_options( +pub(crate) fn simple_launch_options( lhs_shape: &Shape, rhs_shape: &Shape, output_shape: &Shape, @@ -159,3 +159,21 @@ pub(crate) fn launch_options( WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32) } + +pub(crate) fn tiling2d_launch_options( + output_shape: &Shape, + config: Tiling2dConfig, +) -> WorkGroup { + let num_rows = output_shape.dims[D - 2]; + let num_cols = output_shape.dims[D - 1]; + + // set number of workgroups + let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32; + let mut num_iter = 1; + for i in 0..D - 2 { + num_iter *= output_shape.dims[i]; + } + + WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32) +} diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs index f106e1482d..6bc4d2489e 100644 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ b/crates/burn-jit/src/kernel/matmul/simple.rs @@ -13,7 +13,7 @@ use crate::{ }; use std::marker::PhantomData; -use super::launch_options; +use super::simple_launch_options; #[derive(new, Debug)] struct MatmulEagerKernel { @@ -228,7 +228,7 @@ pub fn matmul_simple( let lhs = into_contiguous(lhs); let rhs = into_contiguous(rhs); - let workgroup = launch_options( + let workgroup = simple_launch_options( &lhs.shape, &rhs.shape, &out.shape, diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index c6af6d36e0..5372fdd7a4 100644 --- 
a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -13,7 +13,7 @@ use crate::{ }; use std::marker::PhantomData; -use super::{launch_options, Tiling2dConfig}; +use super::{tiling2d_launch_options, Tiling2dConfig}; #[derive(new, Debug)] struct MatmulTiling2d { @@ -231,12 +231,8 @@ pub fn matmul_tiling_2d( EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), ]) .outputs(&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)]) - .execute(WorkgroupLaunch::Custom(launch_options( - &lhs.shape, - &rhs.shape, - &out.shape, - config.grid_x, - config.grid_y, + .execute(WorkgroupLaunch::Custom(tiling2d_launch_options( + &out.shape, config, ))); out diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs index 39e3ee00ae..1a77fae146 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs @@ -6,10 +6,7 @@ use crate::{ Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, element::JitElement, - gpu::{ - gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Item, Scope, - Synchronization, Variable, - }, + gpu::{gpu, BinaryOperator, Branch, Elem, Item, Scope, Synchronization, Variable}, kernel::{into_contiguous, DynamicKernelSource, SourceTemplate}, tensor::JitTensor, Runtime, @@ -17,9 +14,8 @@ use crate::{ use std::marker::PhantomData; use super::{ - launch_options, padding::{crop, pad_round, PaddingOutput}, - shape_out, Tiling2dConfig, + shape_out, tiling2d_launch_options, Tiling2dConfig, }; #[derive(new, Debug)] @@ -64,14 +60,16 @@ impl MatmulTiling2dPaddedShader { // Shapes let rank = Variable::Rank; + let ultimate_dim = scope.create_local(Elem::UInt); let penultimate_dim = scope.create_local(Elem::UInt); - gpu!(scope, penultimate_dim = rank - 1u32); + gpu!(scope, ultimate_dim = rank - 1u32); + gpu!(scope, penultimate_dim = rank - 2u32); let dim_m = scope.create_local(Elem::UInt); let dim_k = scope.create_local(Elem::UInt); let dim_n = scope.create_local(Elem::UInt); gpu!(scope, dim_m = shape(lhs, penultimate_dim)); - gpu!(scope, dim_k = shape(rhs, penultimate_dim)); - gpu!(scope, dim_n = shape(rhs, rank)); + gpu!(scope, dim_k = shape(lhs, ultimate_dim)); + gpu!(scope, dim_n = shape(rhs, ultimate_dim)); // Strides let lhs_stride_row = scope.create_local(Elem::UInt); @@ -81,11 +79,11 @@ impl MatmulTiling2dPaddedShader { let out_stride_row = scope.create_local(Elem::UInt); let out_stride_col = scope.create_local(Elem::UInt); gpu!(scope, lhs_stride_row = stride(lhs, penultimate_dim)); - gpu!(scope, lhs_stride_col = stride(lhs, rank)); + gpu!(scope, lhs_stride_col = stride(lhs, ultimate_dim)); gpu!(scope, rhs_stride_row = stride(rhs, penultimate_dim)); - gpu!(scope, rhs_stride_col = stride(rhs, rank)); + gpu!(scope, rhs_stride_col = stride(rhs, ultimate_dim)); gpu!(scope, out_stride_row = stride(out, penultimate_dim)); - gpu!(scope, out_stride_col = stride(out, rank)); + gpu!(scope, out_stride_col = stride(out, ultimate_dim)); // Workgroup offset let skip_row = scope.create_local(Elem::UInt); @@ -125,15 +123,34 @@ impl MatmulTiling2dPaddedShader { // Batch offset for the lhs & rhs matrices. 
gpu!(scope, batch_dims = rank - 2u32); - IndexOffsetGlobalWithLayout { - tensors: vec![lhs, rhs], - indexes: vec![offset_lhs, offset_rhs], - layout: out, - index_ref: offset_output, - dim_start: 0u32.into(), - dim_end: batch_dims, - } - .expand(scope); + gpu!( + scope, + range(0u32, batch_dims).for_each(|b, scope| { + let stride_lhs = scope.create_local(Elem::UInt); + let stride_rhs = scope.create_local(Elem::UInt); + let stride_output = scope.create_local(Elem::UInt); + let shape_lhs = scope.create_local(Elem::UInt); + let shape_rhs = scope.create_local(Elem::UInt); + + gpu!(scope, stride_lhs = stride(lhs, b)); + gpu!(scope, stride_rhs = stride(rhs, b)); + gpu!(scope, stride_output = stride(out, b)); + gpu!(scope, shape_lhs = shape(lhs, b)); + gpu!(scope, shape_rhs = shape(rhs, b)); + + let tmp = scope.create_local(Elem::UInt); + gpu!(scope, tmp = offset_output / stride_output); + let tmp_lhs = scope.create_local(Elem::UInt); + gpu!(scope, tmp_lhs = tmp % shape_lhs); + gpu!(scope, tmp_lhs = tmp_lhs * stride_lhs); + gpu!(scope, offset_lhs += tmp_lhs); + + let tmp_rhs = scope.create_local(Elem::UInt); + gpu!(scope, tmp_rhs = tmp % shape_rhs); + gpu!(scope, tmp_rhs = tmp_rhs * stride_rhs); + gpu!(scope, offset_rhs += tmp_rhs); + }) + ); let elem = lhs.item().elem(); let results = scope.create_local_array(elem, results_size); @@ -206,6 +223,7 @@ impl MatmulTiling2dPaddedShader { results, offset_output, out, + batch, ); } @@ -276,9 +294,10 @@ impl MatmulTiling2dPaddedShader { let lhs_vec4 = scope.create_local(shared_memory.item()); gpu!(scope, lhs_vec4 = vec4(lhs_0, lhs_1, lhs_2, lhs_3)); gpu!(scope, shared_memory[lhs_sm_position] = lhs_vec4); - }).else(|scope|{ - scope.register(Branch::Break); // TODO test if faster, else remove })); + // }).else(|scope|{ + // scope.register(Branch::Break); // TODO test if faster, else remove + // })); }) ); } @@ -369,6 +388,7 @@ impl MatmulTiling2dPaddedShader { results: Variable, offset_output: Variable, out: Variable, + tmp: Variable, ) { let elem = results.item().elem(); @@ -403,6 +423,7 @@ impl MatmulTiling2dPaddedShader { ); gpu!(scope, output_position += offset_output); + // gpu!(scope, out[output_position] = tmp); gpu!(scope, out[output_position] = result); } ) @@ -413,8 +434,6 @@ impl MatmulTiling2dPaddedShader { } } -// TODO try to reuse variables declared before the scope - impl DynamicKernelSource for MatmulTiling2dPaddedEagerKernel { fn source(&self) -> SourceTemplate { let mut scope = gpu::Scope::root(); @@ -497,6 +516,8 @@ pub fn matmul_tiling_2d_padded round_rhs.into_tensor(), }; + println!("{:?}", lhs.shape); + println!("{:?}", rhs.shape); let rounded_output_shape = shape_out(&lhs, &rhs); @@ -508,6 +529,7 @@ pub fn matmul_tiling_2d_padded input_0_global: array; + +@group(0) +@binding(1) +var input_1_global: array; + +@group(0) +@binding(2) +var output_0_global: array; + +@group(0) +@binding(3) +var info: array; + +var shared_memory_0: array, 512>; + +var shared_memory_1: array, 512>; + +const WORKGROUP_SIZE_X = 16u; +const WORKGROUP_SIZE_Y = 16u; +const WORKGROUP_SIZE_Z = 1u; + +@compute +@workgroup_size(16, 16, 1) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_index) local_idx: u32, + @builtin(workgroup_id) workgroup_id: vec3, +) { + var a_0_0: array; + + let rank: u32 = info[0]; + let rank_2: u32 = rank * 2u; + var l_0_0: u32; + var l_0_1: u32; + var l_0_2: u32; + var l_0_3: u32; + var l_0_4: u32; + var l_0_5: u32; + var l_0_6: u32; + var l_0_7: u32; + var l_0_8: u32; + var l_0_9: u32; + var l_0_10: 
u32; + var l_0_11: u32; + var l_0_12: u32; + var l_0_13: u32; + var l_0_14: u32; + var l_0_15: u32; + var l_0_16: u32; + var l_0_17: u32; + var l_0_18: u32; + var l_0_19: u32; + var l_0_20: vec4; + var l_0_21: vec4; + var l_0_22: u32; + l_0_0 = rank - 1u; + l_0_1 = info[(0u * rank_2) + rank + l_0_0 + 1u]; + l_0_2 = info[(1u * rank_2) + rank + l_0_0 + 1u]; + l_0_3 = info[(1u * rank_2) + rank + rank + 1u]; + l_0_4 = info[(0u * rank_2) + l_0_0 + 1u]; + l_0_5 = info[(0u * rank_2) + rank + 1u]; + l_0_6 = info[(1u * rank_2) + l_0_0 + 1u]; + l_0_7 = info[(1u * rank_2) + rank + 1u]; + l_0_8 = info[(2u * rank_2) + l_0_0 + 1u]; + l_0_9 = info[(2u * rank_2) + rank + 1u]; + l_0_10 = u32(workgroup_id.x); + l_0_10 = l_0_10 * 64u; + l_0_11 = u32(workgroup_id.y); + l_0_11 = l_0_11 * 64u; + l_0_12 = local_idx / 16u; + l_0_12 = l_0_12 * 4u; + l_0_13 = local_idx % 16u; + l_0_13 = l_0_13 * 4u; + l_0_14 = l_0_10 + l_0_12; + l_0_15 = l_0_11 + l_0_13; + l_0_16 = l_0_10 * l_0_4; + l_0_17 = l_0_11 * l_0_7; + l_0_18 = l_0_1 * l_0_3; + l_0_18 = l_0_18 * global_id.z; + l_0_19 = rank - 2u; + l_0_16 = u32(0u); + l_0_17 = u32(0u); + + for (var l_1_0: u32 = 0u; l_1_0 < l_0_19; l_1_0++) { + var l_1_1: u32; + var l_1_2: u32; + var l_1_3: u32; + var l_1_4: u32; + var l_1_5: u32; + var l_1_6: u32; + var l_1_7: u32; + var l_1_8: u32; + l_1_1 = info[(2u * rank_2) + l_1_0 + 1u]; + l_1_2 = l_0_18 * 1u; + l_1_2 = l_1_2 / l_1_1; + l_1_3 = info[(0u * rank_2) + l_1_0 + 1u]; + l_1_4 = info[(0u * rank_2) + rank + l_1_0 + 1u]; + l_1_5 = l_1_2 % l_1_4; + l_1_5 = l_1_5 * l_1_3; + l_0_16 = l_0_16 + l_1_5; + l_1_6 = info[(1u * rank_2) + l_1_0 + 1u]; + l_1_7 = info[(1u * rank_2) + rank + l_1_0 + 1u]; + l_1_8 = l_1_2 % l_1_7; + l_1_8 = l_1_8 * l_1_6; + l_0_17 = l_0_17 + l_1_8; + } + l_0_16 = l_0_16 / 1u; + l_0_17 = l_0_17 / 1u; + l_0_22 = l_0_2 / 32u; + + for (var l_1_0: u32 = 0u; l_1_0 < l_0_22; l_1_0++) { + var l_1_1: u32; + l_1_1 = l_1_0 * 32u; + + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + var l_2_1: u32; + var l_2_2: bool; + l_2_1 = l_0_13 + l_2_0; + l_2_2 = l_2_1 < 32u; + if l_2_2 { + var l_3_0: u32; + var l_3_1: u32; + var l_3_2: u32; + var l_3_3: u32; + var l_3_4: u32; + var l_3_5: u32; + var l_3_6: f32; + var l_3_7: f32; + var l_3_8: f32; + var l_3_9: f32; + var l_3_10: vec4; + l_3_0 = l_0_12 / 4u; + l_3_0 = l_3_0 * 32u; + l_3_0 = l_3_0 + l_2_1; + l_3_1 = l_1_1 + l_2_1; + l_3_1 = l_3_1 * l_0_5; + l_3_2 = l_0_12 * l_0_4; + l_3_1 = l_3_1 + l_3_2; + l_3_1 = l_3_1 + l_0_16; + l_3_3 = l_3_1 + l_0_4; + l_3_4 = l_3_3 + l_0_4; + l_3_5 = l_3_4 + l_0_4; + l_3_6 = f32(input_0_global[l_3_1]); + l_3_7 = f32(input_0_global[l_3_3]); + l_3_8 = f32(input_0_global[l_3_4]); + l_3_9 = f32(input_0_global[l_3_5]); + l_3_10 = vec4(l_3_6, l_3_7, l_3_8, l_3_9); + shared_memory_0[l_3_0] = vec4(l_3_10); + } else { + break; + } + } + + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + var l_2_1: u32; + var l_2_2: bool; + l_2_1 = l_0_12 + l_2_0; + l_2_2 = l_2_1 < 32u; + if l_2_2 { + var l_3_0: u32; + var l_3_1: u32; + var l_3_2: u32; + var l_3_3: u32; + var l_3_4: u32; + var l_3_5: u32; + var l_3_6: f32; + var l_3_7: f32; + var l_3_8: f32; + var l_3_9: f32; + var l_3_10: vec4; + l_3_0 = l_2_1 * 64u; + l_3_0 = l_3_0 + l_0_13; + l_3_0 = l_3_0 / 4u; + l_3_1 = l_1_1 + l_2_1; + l_3_1 = l_3_1 * l_0_6; + l_3_2 = l_0_13 * l_0_7; + l_3_1 = l_3_1 + l_3_2; + l_3_1 = l_3_1 + l_0_17; + l_3_3 = l_3_1 + l_0_7; + l_3_4 = l_3_3 + l_0_7; + l_3_5 = l_3_4 + l_0_7; + l_3_6 = f32(input_1_global[l_3_1]); + l_3_7 = f32(input_1_global[l_3_3]); + l_3_8 = 
f32(input_1_global[l_3_4]); + l_3_9 = f32(input_1_global[l_3_5]); + l_3_10 = vec4(l_3_6, l_3_7, l_3_8, l_3_9); + shared_memory_1[l_3_0] = vec4(l_3_10); + } else { + break; + } + } + workgroupBarrier(); + + for (var l_2_0: u32 = 0u; l_2_0 < 32u; l_2_0++) { + var l_2_1: u32; + var l_2_2: u32; + l_2_1 = l_0_12 / 4u; + l_2_1 = l_2_1 * 32u; + l_2_1 = l_2_1 + l_2_0; + l_0_20 = vec4(shared_memory_0[l_2_1]); + l_2_2 = l_2_0 * 64u; + l_2_2 = l_2_2 + l_0_13; + l_2_2 = l_2_2 / 4u; + l_0_21 = vec4(shared_memory_1[l_2_2]); + + for (var l_3_0: u32 = 0u; l_3_0 < 4u; l_3_0++) { + for (var l_4_0: u32 = 0u; l_4_0 < 4u; l_4_0++) { + var l_4_1: f32; + var l_4_2: f32; + var l_4_3: f32; + var l_4_4: u32; + var l_4_5: f32; + var l_4_6: f32; + l_4_1 = f32(l_0_20[l_3_0]); + l_4_2 = f32(l_0_21[l_4_0]); + l_4_3 = l_4_1 * l_4_2; + l_4_4 = l_3_0 * 4u; + l_4_4 = l_4_4 + l_4_0; + l_4_5 = f32(a_0_0[l_4_4]); + l_4_6 = l_4_5 + l_4_3; + a_0_0[l_4_4] = f32(l_4_6); + } + } + } + workgroupBarrier(); + } + + for (var l_1_0: u32 = 0u; l_1_0 < 4u; l_1_0++) { + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + var l_2_1: u32; + var l_2_2: f32; + var l_2_3: u32; + var l_2_4: u32; + var l_2_5: u32; + l_2_1 = l_1_0 * 4u; + l_2_1 = l_2_1 + l_2_0; + l_2_2 = f32(a_0_0[l_2_1]); + l_2_4 = l_0_14 + l_1_0; + l_2_4 = l_2_4 * l_0_8; + l_2_5 = l_0_15 + l_2_0; + l_2_5 = l_2_5 * l_0_9; + l_2_3 = l_2_4 + l_2_5; + l_2_3 = l_2_3 + l_0_18; + output_0_global[l_2_3] = f32(l_2_2); + } + } +} \ No newline at end of file diff --git a/crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul.wgsl b/crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul.wgsl new file mode 100644 index 0000000000..1b4021f142 --- /dev/null +++ b/crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul.wgsl @@ -0,0 +1,211 @@ +@group(0) +@binding(0) +var input_0_global: array; + +@group(0) +@binding(1) +var input_1_global: array; + +@group(0) +@binding(2) +var output_0_global: array; + +@group(0) +@binding(3) +var info: array; + +var shared_memory_0: array, 512>; + +var shared_memory_1: array, 512>; + +const WORKGROUP_SIZE_X = 16u; +const WORKGROUP_SIZE_Y = 16u; +const WORKGROUP_SIZE_Z = 1u; + +@compute +@workgroup_size(16, 16, 1) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_index) local_idx: u32, + @builtin(workgroup_id) workgroup_id: vec3, +) { + var results: array; + + let rank: u32 = info[0]; + let rank_2: u32 = rank * 2u; + var l_0_0: u32; + var M: u32; + var K: u32; + var N: u32; + var lhs_stride_row: u32; + var lhs_stride_col: u32; + var rhs_stride_row: u32; + var rhs_stride_col: u32; + var out_stride_row: u32; + var out_stride_col: u32; + var skip_row: u32; + var skip_col: u32; + var thread_row: u32; + var thread_col: u32; + var row: u32; + var col: u32; + var offset_lhs: u32; + var offset_rhs: u32; + var offset_output: u32; + var l_0_19: u32; + var register_m: vec4; + var register_n: vec4; + var n_loops: u32; + l_0_0 = rank - 1u; + M = info[(0u * rank_2) + rank + l_0_0 + 1u]; + K = info[(1u * rank_2) + rank + l_0_0 + 1u]; + N = info[(1u * rank_2) + rank + rank + 1u]; + lhs_stride_row = info[(0u * rank_2) + l_0_0 + 1u]; + lhs_stride_col = info[(0u * rank_2) + rank + 1u]; + rhs_stride_row = info[(1u * rank_2) + l_0_0 + 1u]; + rhs_stride_col = info[(1u * rank_2) + rank + 1u]; + out_stride_row = info[(2u * rank_2) + l_0_0 + 1u]; + out_stride_col = info[(2u * rank_2) + rank + 1u]; + skip_row = u32(workgroup_id.x) * 64u; + skip_col = u32(workgroup_id.y) * 64u; + thread_row = (local_idx / 16u)*4u; 
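+    // Annotation added for clarity (not emitted by the codegen): thread_row
+    // above and thread_col below locate this thread's 4x4 output tile within
+    // the 64x64 block, derived from the flat 16x16 local invocation index.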
+    thread_col = (local_idx % 16u)*4u;
+    row = skip_row + thread_row;
+    col = skip_col + thread_col;
+    offset_lhs = skip_row * lhs_stride_row;
+    offset_rhs = skip_col * rhs_stride_col;
+    offset_output = M * N;
+    offset_output = offset_output * global_id.z;
+    l_0_19 = rank - 2u;
+
+    for (var l_1_0: u32 = 0u; l_1_0 < l_0_19; l_1_0++) {
+        var l_1_1: u32;
+        var l_1_2: u32;
+        var l_1_3: u32;
+        var l_1_4: u32;
+        var l_1_5: u32;
+        var l_1_6: u32;
+        var l_1_7: u32;
+        var l_1_8: u32;
+        l_1_1 = info[(0u * rank_2) + l_1_0 + 1u];
+        l_1_2 = info[(1u * rank_2) + l_1_0 + 1u];
+        l_1_3 = info[(2u * rank_2) + l_1_0 + 1u];
+        l_1_4 = info[(0u * rank_2) + rank + l_1_0 + 1u];
+        l_1_5 = info[(1u * rank_2) + rank + l_1_0 + 1u];
+        l_1_6 = offset_output / l_1_3;
+        l_1_7 = l_1_6 % l_1_4;
+        l_1_7 = l_1_7 * l_1_1;
+        offset_lhs = offset_lhs + l_1_7;
+        l_1_8 = l_1_6 % l_1_5;
+        l_1_8 = l_1_8 * l_1_2;
+        offset_rhs = offset_rhs + l_1_8;
+    }
+    n_loops = K / 32u;
+
+    for (var i: u32 = 0u; i < n_loops; i++) {
+        var k: u32;
+        k = i * 32u;
+
+        // Load a vec4 of lhs (4 consecutive rows, one k-column) into shared memory.
+        for (var j: u32 = 0u; j < 4u; j++) {
+            var current_col: u32;
+            current_col = thread_col + j;
+            if current_col < 32u {
+                var lhs_sm_position: u32;
+                var lhs_position0: u32;
+                var lhs_position1: u32;
+                var lhs_position2: u32;
+                var lhs_position3: u32;
+                var l_3_6: f32;
+                var l_3_7: f32;
+                var l_3_8: f32;
+                var l_3_9: f32;
+                var l_3_10: vec4<f32>;
+                lhs_sm_position = (thread_row / 4u) * 32u + current_col;
+                lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
+                lhs_position1 = lhs_position0 + lhs_stride_row;
+                lhs_position2 = lhs_position1 + lhs_stride_row;
+                lhs_position3 = lhs_position2 + lhs_stride_row;
+                l_3_6 = f32(input_0_global[lhs_position0]);
+                l_3_7 = f32(input_0_global[lhs_position1]);
+                l_3_8 = f32(input_0_global[lhs_position2]);
+                l_3_9 = f32(input_0_global[lhs_position3]);
+                l_3_10 = vec4<f32>(l_3_6, l_3_7, l_3_8, l_3_9);
+                shared_memory_0[lhs_sm_position] = vec4<f32>(l_3_10);
+            } else {
+                break;
+            }
+        }
+
+        // Load a vec4 of rhs (4 consecutive columns, one k-row) into shared memory.
+        for (var j: u32 = 0u; j < 4u; j++) {
+            var current_row: u32;
+            current_row = thread_row + j;
+            if current_row < 32u {
+                var rhs_sm_position: u32;
+                var rhs_position0: u32;
+                var rhs_position1: u32;
+                var rhs_position2: u32;
+                var rhs_position3: u32;
+                var l_3_6: f32;
+                var l_3_7: f32;
+                var l_3_8: f32;
+                var l_3_9: f32;
+                var l_3_10: vec4<f32>;
+                rhs_sm_position = (current_row * 64u + thread_col) / 4u;
+                rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
+                rhs_position1 = rhs_position0 + rhs_stride_col;
+                rhs_position2 = rhs_position1 + rhs_stride_col;
+                rhs_position3 = rhs_position2 + rhs_stride_col;
+                l_3_6 = f32(input_1_global[rhs_position0]);
+                l_3_7 = f32(input_1_global[rhs_position1]);
+                l_3_8 = f32(input_1_global[rhs_position2]);
+                l_3_9 = f32(input_1_global[rhs_position3]);
+                l_3_10 = vec4<f32>(l_3_6, l_3_7, l_3_8, l_3_9);
+                shared_memory_1[rhs_sm_position] = vec4<f32>(l_3_10);
+            } else {
+                break;
+            }
+        }
+        workgroupBarrier();
+
+        // Accumulate this k-block's contribution to the thread's 4x4 results tile.
+        for (var dot_index: u32 = 0u; dot_index < 32u; dot_index++) {
+            var lhs_sm_position: u32;
+            var rhs_sm_position: u32;
+            lhs_sm_position = (thread_row / 4u)*32u+dot_index;
+            register_m = vec4<f32>(shared_memory_0[lhs_sm_position]);
+            rhs_sm_position = (dot_index * 64u + thread_col) / 4u;
+            register_n = vec4<f32>(shared_memory_1[rhs_sm_position]);
+
+            for (var res_idx_m: u32 = 0u; res_idx_m < 4u; res_idx_m++) {
+                for (var res_idx_n: u32 = 0u; res_idx_n < 4u; res_idx_n++) {
+                    var left: f32;
+                    var right: f32;
+                    var multiplied: f32;
+                    var results_position: u32;
+                    var old: f32;
+                    left = f32(register_m[res_idx_m]);
+                    right = 
f32(register_n[res_idx_n]); + multiplied = left * right; + results_position = res_idx_m * 4u + res_idx_n; + old = f32(results[results_position]); + results[results_position] = f32(old + multiplied); + } + } + } + workgroupBarrier(); + } + + for (var res_idx_m: u32 = 0u; res_idx_m < 4u; res_idx_m++) { + for (var res_idx_n: u32 = 0u; res_idx_n < 4u; res_idx_n++) { + var result_position: u32; + var result: f32; + var output_position: u32; + result_position = res_idx_m * 4u + res_idx_n; + result = f32(results[result_position]); + output_position = (row + res_idx_m) * out_stride_row + (col + res_idx_n) * out_stride_col + offset_output; + output_0_global[output_position] = f32(result); + } + } +} \ No newline at end of file diff --git a/crates/burn-jit/src/tests/matmul.rs b/crates/burn-jit/src/tests/matmul.rs index e91e5cbcc2..1ca5c5b0e3 100644 --- a/crates/burn-jit/src/tests/matmul.rs +++ b/crates/burn-jit/src/tests/matmul.rs @@ -212,6 +212,48 @@ mod tests { ); } + #[test] + fn stable_test() { + let ref_tensor_device = Default::default(); + let x = ReferenceTensor::from_floats([[0., 1., 2.], [3., 4., 5.]], &ref_tensor_device); + let y = + ReferenceTensor::from_floats([[0., 1.], [2., 3.], [4., 5.]], &ref_tensor_device); + + let test_tensor_device = Default::default(); + let x_jit = TestTensor::from_data(x.to_data(), &test_tensor_device); + let y_jit = TestTensor::from_data(y.to_data(), &test_tensor_device); + + let z_reference = x.matmul(y); + let z = Tensor::::from_primitive(matmul( + x_jit.into_primitive(), + y_jit.into_primitive(), + MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), + )); + + z_reference.into_data().assert_approx_eq(&z.into_data(), 3); + } + + #[test] + fn stable_test_2() { + let ref_tensor_device = Default::default(); + let x = + ReferenceTensor::from_floats([[0., 1.], [2., 3.], [4., 5.]], &ref_tensor_device); + let y = ReferenceTensor::from_floats([[0., 1., 2.], [3., 4., 5.]], &ref_tensor_device); + + let test_tensor_device = Default::default(); + let x_jit = TestTensor::from_data(x.to_data(), &test_tensor_device); + let y_jit = TestTensor::from_data(y.to_data(), &test_tensor_device); + + let z_reference = x.matmul(y); + let z = Tensor::::from_primitive(matmul( + x_jit.into_primitive(), + y_jit.into_primitive(), + MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), + )); + + z_reference.into_data().assert_approx_eq(&z.into_data(), 3); + } + fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { let shape_lhs = [batch_1, batch_2, m, k]; let shape_rhs = [batch_1, batch_2, k, n]; @@ -635,6 +677,8 @@ mod tests { y_jit.into_primitive(), strategy, )); + println!("{}", z_reference); + println!("{}", z); z_reference.into_data().assert_approx_eq(&z.into_data(), 3); } From 11f55fdd889c5a8bdddabd2bab0986c541a0f80e Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 13 Mar 2024 18:09:06 -0400 Subject: [PATCH 11/25] fails on unroll --- crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs index 1a77fae146..7d776fe9b2 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs @@ -223,7 +223,6 @@ impl MatmulTiling2dPaddedShader { results, offset_output, out, - batch, ); } @@ -294,10 +293,9 @@ impl MatmulTiling2dPaddedShader { let lhs_vec4 = scope.create_local(shared_memory.item()); gpu!(scope, 
lhs_vec4 = vec4(lhs_0, lhs_1, lhs_2, lhs_3)); gpu!(scope, shared_memory[lhs_sm_position] = lhs_vec4); + }).else(|scope|{ + scope.register(Branch::Break); // TODO test if faster, else remove })); - // }).else(|scope|{ - // scope.register(Branch::Break); // TODO test if faster, else remove - // })); }) ); } @@ -388,7 +386,6 @@ impl MatmulTiling2dPaddedShader { results: Variable, offset_output: Variable, out: Variable, - tmp: Variable, ) { let elem = results.item().elem(); @@ -446,7 +443,7 @@ impl DynamicKernelSource for MatmulTiling2dPaddedEagerKernel { MatmulTiling2dPaddedShader { variables: gpu::BinaryOperator { lhs, rhs, out }, config: self.config.clone(), - unroll: false, + unroll: true, } .expand(&mut scope); @@ -516,8 +513,6 @@ pub fn matmul_tiling_2d_padded round_rhs.into_tensor(), }; - println!("{:?}", lhs.shape); - println!("{:?}", rhs.shape); let rounded_output_shape = shape_out(&lhs, &rhs); @@ -529,7 +524,6 @@ pub fn matmul_tiling_2d_padded Date: Wed, 13 Mar 2024 18:26:17 -0400 Subject: [PATCH 12/25] stupid break --- .../src/kernel/matmul/tiling2d_padded.rs | 66 +++++++++---------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs index 7d776fe9b2..8cb59fa193 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs @@ -37,8 +37,6 @@ struct MatmulTiling2dPaddedShader { impl MatmulTiling2dPaddedShader { fn expand(self, scope: &mut Scope) { - // Phase 1: Gather information: input, shader and offsets - // Inputs let lhs = self.variables.lhs; let rhs = self.variables.rhs; @@ -256,45 +254,43 @@ impl MatmulTiling2dPaddedShader { ); gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ - let lhs_sm_position = scope.create_local(Elem::UInt); + let sm_position = scope.create_local(Elem::UInt); if is_lhs { - gpu!(scope, lhs_sm_position = thread_idx_2 / 4u32); - gpu!(scope, lhs_sm_position *= block_size_k); - gpu!(scope, lhs_sm_position += current_col); + gpu!(scope, sm_position = thread_idx_2 / 4u32); + gpu!(scope, sm_position *= block_size_k); + gpu!(scope, sm_position += current_col); } else { - gpu!(scope, lhs_sm_position = current_col * block_size_n); - gpu!(scope, lhs_sm_position += thread_idx_2); - gpu!(scope, lhs_sm_position = lhs_sm_position / 4u32); + gpu!(scope, sm_position = current_col * block_size_n); + gpu!(scope, sm_position += thread_idx_2); + gpu!(scope, sm_position = sm_position / 4u32); } - let lhs_position_0 = scope.create_local(Elem::UInt); - gpu!(scope, lhs_position_0 = k + current_col); - gpu!(scope, lhs_position_0 *= stride_1); + let position_0 = scope.create_local(Elem::UInt); + gpu!(scope, position_0 = k + current_col); + gpu!(scope, position_0 *= stride_1); let tmp = scope.create_local(Elem::UInt); gpu!(scope, tmp = thread_idx_2 * stride_2); - gpu!(scope, lhs_position_0 += tmp); - gpu!(scope, lhs_position_0 += input_offset); - let lhs_position_1 = scope.create_local(Elem::UInt); - let lhs_position_2 = scope.create_local(Elem::UInt); - let lhs_position_3 = scope.create_local(Elem::UInt); - gpu!(scope, lhs_position_1 = lhs_position_0 + stride_2); - gpu!(scope, lhs_position_2 = lhs_position_1 + stride_2); - gpu!(scope, lhs_position_3 = lhs_position_2 + stride_2); - - let lhs_0 = scope.create_local(elem); - let lhs_1 = scope.create_local(elem); - let lhs_2 = scope.create_local(elem); - let lhs_3 = scope.create_local(elem); - gpu!(scope, lhs_0 = input[lhs_position_0]); - 
gpu!(scope, lhs_1 = input[lhs_position_1]); - gpu!(scope, lhs_2 = input[lhs_position_2]); - gpu!(scope, lhs_3 = input[lhs_position_3]); - - let lhs_vec4 = scope.create_local(shared_memory.item()); - gpu!(scope, lhs_vec4 = vec4(lhs_0, lhs_1, lhs_2, lhs_3)); - gpu!(scope, shared_memory[lhs_sm_position] = lhs_vec4); - }).else(|scope|{ - scope.register(Branch::Break); // TODO test if faster, else remove + gpu!(scope, position_0 += tmp); + gpu!(scope, position_0 += input_offset); + let position_1 = scope.create_local(Elem::UInt); + let position_2 = scope.create_local(Elem::UInt); + let position_3 = scope.create_local(Elem::UInt); + gpu!(scope, position_1 = position_0 + stride_2); + gpu!(scope, position_2 = position_1 + stride_2); + gpu!(scope, position_3 = position_2 + stride_2); + + let val_0 = scope.create_local(elem); + let val_1 = scope.create_local(elem); + let val_2 = scope.create_local(elem); + let val_3 = scope.create_local(elem); + gpu!(scope, val_0 = input[position_0]); + gpu!(scope, val_1 = input[position_1]); + gpu!(scope, val_2 = input[position_2]); + gpu!(scope, val_3 = input[position_3]); + + let val_vec4 = scope.create_local(shared_memory.item()); + gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); + gpu!(scope, shared_memory[sm_position] = val_vec4); })); }) ); From a0bb940af6a970ff70db00528e834767bd5a4c6c Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 14 Mar 2024 11:22:58 -0400 Subject: [PATCH 13/25] tiling2d no assumption works --- .../src/codegen/dialect/gpu/operation.rs | 1 + .../src/codegen/dialect/gpu/vectorization.rs | 1 + crates/burn-jit/src/fusion/tracing/builder.rs | 5 + crates/burn-jit/src/kernel/matmul/base.rs | 11 +- crates/burn-jit/src/kernel/matmul/mod.rs | 4 +- crates/burn-jit/src/kernel/matmul/tiling2d.rs | 219 +++----- ...{tiling2d_padded.rs => tiling2d_shader.rs} | 507 +++++++++++------- .../blocktiling_2d/generated_matmul copy.wgsl | 250 --------- .../blocktiling_2d/generated_matmul.wgsl | 211 -------- crates/burn-jit/src/tensor/base.rs | 1 - crates/burn-jit/src/tests/matmul.rs | 2 - .../burn-wgpu/src/compiler/wgsl/compiler.rs | 4 + .../src/compiler/wgsl/instructions.rs | 7 + crates/burn-wgpu/src/compute/server.rs | 1 - 14 files changed, 414 insertions(+), 810 deletions(-) rename crates/burn-jit/src/kernel/matmul/{tiling2d_padded.rs => tiling2d_shader.rs} (50%) delete mode 100644 crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul copy.wgsl delete mode 100644 crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul.wgsl diff --git a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs index bad960a1b8..37193854d1 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs @@ -36,6 +36,7 @@ pub enum Operator { Tanh(UnaryOperator), Powf(BinaryOperator), Sqrt(UnaryOperator), + Ceil(UnaryOperator), Erf(UnaryOperator), Recip(UnaryOperator), Equal(BinaryOperator), diff --git a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs index 936358cf45..f9a42379de 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs @@ -55,6 +55,7 @@ impl Operator { Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)), Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)), Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)), + Operator::Ceil(op) => 
Operator::Ceil(op.vectorize(vectorization)), Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)), Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)), Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)), diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index 0172b83ffa..a9eaa2bf0f 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -247,6 +247,11 @@ impl TraceBuilder { &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), + gpu::Operator::Ceil(op) => mark_unary( + op, + &mut local_tensor_ids_input, + &mut local_tensor_ids_output, + ), gpu::Operator::Log(op) => mark_unary( op, &mut local_tensor_ids_input, diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index a4a18c8d5b..301968c493 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -5,10 +5,17 @@ use burn_tensor::Shape; use crate::{compute::WorkGroup, tensor::JitTensor, Compiler, JitElement, Runtime}; use super::{ - init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, - tiling2d_padded::matmul_tiling_2d_padded, + init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, matmul_tiling_2d_padded, }; +#[derive(Debug, Clone)] +pub(crate) enum Tiling2DAssumption { + // Input shapes are divisible by their corresponding block sizes + Round, + // Bounds must be checked + None, +} + #[derive(Debug, Clone)] /// Tiling 2D parameters pub struct Tiling2dConfig { diff --git a/crates/burn-jit/src/kernel/matmul/mod.rs b/crates/burn-jit/src/kernel/matmul/mod.rs index 9c405e060d..ba4830b2e1 100644 --- a/crates/burn-jit/src/kernel/matmul/mod.rs +++ b/crates/burn-jit/src/kernel/matmul/mod.rs @@ -1,7 +1,7 @@ mod base; mod simple; mod tiling2d; -mod tiling2d_padded; +mod tiling2d_shader; mod tune; /// Contains utilitary for matmul operation @@ -20,4 +20,4 @@ pub mod padding; mod padding; pub use tiling2d::*; -pub use tiling2d_padded::*; +use tiling2d_shader::*; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index 5372fdd7a4..ed43fa0cf9 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -1,4 +1,4 @@ -use burn_tensor::Element; +use burn_tensor::{Element, Shape}; use crate::{ codegen::{ @@ -6,14 +6,16 @@ use crate::{ Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, element::JitElement, - gpu::{gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable}, kernel::{into_contiguous, DynamicKernelSource, SourceTemplate}, tensor::JitTensor, Runtime, }; use std::marker::PhantomData; -use super::{tiling2d_launch_options, Tiling2dConfig}; +use super::{ + padding::{crop, pad_round, PaddingOutput}, + shape_out, tiling2d_launch_options, MatmulTiling2dShader, Tiling2DAssumption, Tiling2dConfig, +}; #[derive(new, Debug)] struct MatmulTiling2d { @@ -23,136 +25,10 @@ struct MatmulTiling2d { #[derive(new, Debug)] struct MatmulTiling2dEagerKernel { config: Tiling2dConfig, + assumption: Tiling2DAssumption, _runtime: PhantomData, } -struct MatmulTiling2dShader { - variables: BinaryOperator, - block_size: usize, -} - -impl MatmulTiling2dShader { - fn expand(self, scope: &mut Scope) { - // Define out global variables. 
- let local_idx = Variable::LocalInvocationIndex; - let batch = Variable::GlobalInvocationIdZ; - let rank = Variable::Rank; - let block_size: Variable = self.block_size.into(); - - // Extract tensor variables. - let lhs = self.variables.lhs; - let rhs = self.variables.rhs; - let out = self.variables.out; - - // Define where we have to work on the current matrix. - let tmp_index = scope.create_local(Elem::UInt); - let batch_dims = scope.create_local(Elem::UInt); - let row = scope.create_local(Elem::UInt); - let col = scope.create_local(Elem::UInt); - - // Row position. - gpu!(scope, tmp_index = local_idx / block_size); - gpu!(scope, row = block_size * Variable::WorkgroupIdX); - gpu!(scope, row = row + tmp_index); - - // Col position. - gpu!(scope, tmp_index = local_idx % block_size); - gpu!(scope, col = block_size * Variable::WorkgroupIdY); - gpu!(scope, col = col + tmp_index); - - // Batch position. - gpu!(scope, batch_dims = rank - 2u32); - - // Define the matrix size. - let n_rows = scope.create_local(Elem::UInt); - let n_cols = scope.create_local(Elem::UInt); - let k = scope.create_local(Elem::UInt); - - // Number of rows. - gpu!(scope, n_rows = shape(out, batch_dims)); - - // Number of cols. - gpu!(scope, tmp_index = batch_dims + 1u32); - gpu!(scope, n_cols = shape(out, tmp_index)); - - // The dimension that is going to be squashed. - gpu!(scope, k = shape(lhs, tmp_index)); - - // Check if there is some work to be done. - let should_stop = scope.create_local(Elem::Bool); - gpu!(scope, should_stop = row >= n_rows); - gpu!(scope, if (should_stop).then(|scope| { - scope.register(Branch::Return); - })); - - gpu!(scope, should_stop = col >= n_cols); - gpu!(scope, if (should_stop).then(|scope| { - scope.register(Branch::Return); - })); - - // Calculate the batch offset. - let offset_lhs = scope.zero(Elem::UInt); - let offset_rhs = scope.zero(Elem::UInt); - let offset_output = scope.create_local(Elem::UInt); - - // Batch offset for the output. - gpu!(scope, offset_output = n_rows * n_cols); - gpu!(scope, offset_output = offset_output * batch); - - // Batch offset for the lhs & rhs matrices. - IndexOffsetGlobalWithLayout { - tensors: vec![lhs, rhs], - indexes: vec![offset_lhs, offset_rhs], - layout: out, - index_ref: offset_output, - dim_start: 0u32.into(), - dim_end: batch_dims, - } - .expand(scope); - - // Calculate the dot product (row X col). - let sum = scope.create_local(out.item()); - - // Initialize the sum to zero. - let zero: Variable = 0f32.into(); - gpu!(scope, sum = zero); - - // Loop over the k dimension. 
- gpu!( - scope, - range(0u32, k).for_each(|i, scope| { - let lhs_index = scope.create_local(Elem::UInt); - let rhs_index = scope.create_local(Elem::UInt); - - let lhs_value = scope.create_local(lhs.item()); - let rhs_value = scope.create_local(rhs.item()); - let out_value = scope.create_local(out.item()); - - gpu!(scope, lhs_index = row * k); - gpu!(scope, lhs_index = lhs_index + i); - gpu!(scope, lhs_index = lhs_index + offset_lhs); - - gpu!(scope, rhs_index = i * n_cols); - gpu!(scope, rhs_index = rhs_index + col); - gpu!(scope, rhs_index = rhs_index + offset_rhs); - - gpu!(scope, lhs_value = lhs[lhs_index]); - gpu!(scope, rhs_value = rhs[rhs_index]); - - gpu!(scope, out_value = lhs_value * rhs_value); - gpu!(scope, sum += out_value); - }) - ); - - let out_index = scope.create_local(Elem::UInt); - - gpu!(scope, out_index = row * n_cols); - gpu!(scope, out_index += col); - gpu!(scope, out_index += offset_output); - gpu!(scope, out[out_index] = sum); - } -} - impl DynamicKernelSource for MatmulTiling2dEagerKernel { fn source(&self) -> SourceTemplate { let mut scope = gpu::Scope::root(); @@ -164,7 +40,9 @@ impl DynamicKernelSource for MatmulTiling2dEagerKernel { MatmulTiling2dShader { variables: gpu::BinaryOperator { lhs, rhs, out }, - block_size: self.config.grid_x, // TODO + config: self.config.clone(), + assumption: self.assumption.clone(), + unroll: false, } .expand(&mut scope); @@ -198,9 +76,10 @@ impl DynamicKernelSource for MatmulTiling2dEagerKernel { fn id(&self) -> String { format!( - "{:?}config={:?}", + "{:?}config={:?}assumption={:?}", core::any::TypeId::of::(), self.config, + self.assumption ) } } @@ -213,7 +92,9 @@ pub fn matmul_tiling_2d( out: JitTensor, config: Tiling2dConfig, ) -> JitTensor { - let kernel = MatmulTiling2dEagerKernel::::new(config.clone()); + let assumption = check_assumption(&lhs.shape, &rhs.shape, &config); + + let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), assumption); let client = lhs.client.clone(); let lhs = match lhs.batch_swapped_with_row_col() { @@ -237,3 +118,75 @@ pub fn matmul_tiling_2d( out } + +/// Matrix multiplication using tiling 2d algorithm with padding needed +pub fn matmul_tiling_2d_padded( + lhs: JitTensor, + rhs: JitTensor, + out: JitTensor, + config: Tiling2dConfig, +) -> JitTensor { + let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), Tiling2DAssumption::Round); + let client = lhs.client.clone(); + + // A tensor may need to be padded, in which case it will implicitly become contiguous + // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. + // If batches were swapped among themselves, or if the last two dims are transposed, the underlying + // kernel handles it without needing to turn it into contiguous. 
+ let round_lhs = pad_round::(lhs, config.block_size_m, config.block_size_k); + let lhs = match round_lhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_lhs.into_tensor(), + }; + let round_rhs = pad_round::(rhs, config.block_size_k, config.block_size_n); + let rhs = match round_rhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_rhs.into_tensor(), + }; + + let rounded_output_shape = shape_out(&lhs, &rhs); + + let num_elems = rounded_output_shape.num_elements(); + let buffer = client.empty(num_elems * core::mem::size_of::()); + let rounded_output = JitTensor::new( + rhs.client.clone(), + rhs.device.clone(), + rounded_output_shape.clone(), + buffer, + ); + + Execution::start(kernel, client) + .inputs(&[ + EagerHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + ]) + .outputs(&[EagerHandle::new( + &rounded_output.handle, + &rounded_output.strides, + &rounded_output.shape.dims, + )]) + .execute(WorkgroupLaunch::Custom(tiling2d_launch_options( + &rounded_output.shape, + config, + ))); + + crop(rounded_output, out) +} + +fn check_assumption( + lhs_shape: &Shape, + rhs_shape: &Shape, + config: &Tiling2dConfig, +) -> Tiling2DAssumption { + let m_divisible = lhs_shape.dims[D - 2] % config.block_size_m == 0; + let k_divisible = lhs_shape.dims[D - 1] % config.block_size_k == 0; + let n_divisible = rhs_shape.dims[D - 1] % config.block_size_n == 0; + match m_divisible && k_divisible && n_divisible { + true => Tiling2DAssumption::Round, + false => Tiling2DAssumption::None, + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs similarity index 50% rename from crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs rename to crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs index 8cb59fa193..ef04abe05a 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_padded.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs @@ -1,42 +1,16 @@ -use burn_tensor::Element; - -use crate::{ - codegen::{ - dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, - Execution, InputInfo, OutputInfo, WorkgroupLaunch, - }, - element::JitElement, - gpu::{gpu, BinaryOperator, Branch, Elem, Item, Scope, Synchronization, Variable}, - kernel::{into_contiguous, DynamicKernelSource, SourceTemplate}, - tensor::JitTensor, - Runtime, -}; -use std::marker::PhantomData; - -use super::{ - padding::{crop, pad_round, PaddingOutput}, - shape_out, tiling2d_launch_options, Tiling2dConfig, -}; - -#[derive(new, Debug)] -struct MatmulTiling2dPadded { - _elem: PhantomData, -} +use crate::gpu::{gpu, BinaryOperator, Elem, Item, Scope, Synchronization, Variable}; -#[derive(new, Debug)] -struct MatmulTiling2dPaddedEagerKernel { - config: Tiling2dConfig, - _runtime: PhantomData, -} +use super::{Tiling2DAssumption, Tiling2dConfig}; -struct MatmulTiling2dPaddedShader { - variables: BinaryOperator, - config: Tiling2dConfig, - unroll: bool, +pub(crate) struct MatmulTiling2dShader { + pub variables: BinaryOperator, + pub config: Tiling2dConfig, + pub assumption: Tiling2DAssumption, + pub unroll: bool, } -impl MatmulTiling2dPaddedShader { - fn expand(self, scope: &mut Scope) { +impl MatmulTiling2dShader { + pub(crate) fn expand(self, scope: &mut Scope) { // Inputs let lhs = self.variables.lhs; let rhs = self.variables.rhs; 
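Note on check_assumption (from the tiling2d.rs hunk above): a minimal,
self-contained sketch, not part of the patch, of how the divisibility check
selects the shader variant. The block sizes assume the default Tiling2dConfig
(64/32/64), and assumption_for is a hypothetical helper, not an API in the
crate:

    // Mirrors check_assumption: every dim must be divisible by its block size.
    fn assumption_for(m: usize, k: usize, n: usize) -> &'static str {
        let (block_m, block_k, block_n) = (64, 32, 64); // assumed defaults
        if m % block_m == 0 && k % block_k == 0 && n % block_n == 0 {
            "Round" // no bound checks; n_loops = k / block_k exactly
        } else {
            "None" // bound-checked loads/writes; n_loops = ceil(k / block_k)
        }
    }

    // assumption_for(64, 32, 64) == "Round"; assumption_for(62, 32, 64) == "None"

With Round, matmul_tiling_2d runs the kernel without bound checks; otherwise
the same shader is generated with the bound-checked loads shown in the
following hunks.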
@@ -93,7 +67,7 @@ impl MatmulTiling2dPaddedShader { gpu!(scope, skip_col = workgroup_id_y); gpu!(scope, skip_col *= block_size_n); - // Invocation offset + // Position of the first element of the thread, relative to the block let thread_row = scope.create_local(Elem::UInt); gpu!(scope, thread_row = local_idx / n_threads_per_row); gpu!(scope, thread_row *= tile_size_m); @@ -101,7 +75,7 @@ impl MatmulTiling2dPaddedShader { gpu!(scope, thread_col = local_idx % n_threads_per_row); gpu!(scope, thread_col *= tile_size_n); - // Row and col + // Position of the first element of the thread, in absolute (in one batch) let row = scope.create_local(Elem::UInt); let col = scope.create_local(Elem::UInt); gpu!(scope, row = skip_row + thread_row); @@ -151,6 +125,8 @@ impl MatmulTiling2dPaddedShader { ); let elem = lhs.item().elem(); + + // Registers used in the compute pass let results = scope.create_local_array(elem, results_size); let register_m = scope.create_local(Item::Vec4(elem)); let register_n = scope.create_local(Item::Vec4(elem)); @@ -164,7 +140,22 @@ impl MatmulTiling2dPaddedShader { ); let n_loops = scope.create_local(Elem::UInt); - gpu!(scope, n_loops = dim_k / block_size_k); // assumes padding, otherwise ceil + match self.assumption { + Tiling2DAssumption::Round => { + gpu!(scope, n_loops = dim_k / block_size_k); + } + Tiling2DAssumption::None => { + let dim_k_float = scope.create_local(elem); + let block_size_k_float = scope.create_local(elem); + let n_loops_float = scope.create_local(elem); + gpu!(scope, dim_k_float = dim_k); + gpu!(scope, block_size_k_float = block_size_k); + gpu!(scope, n_loops_float = dim_k_float / block_size_k_float); + gpu!(scope, n_loops_float = ceil(n_loops_float)); + gpu!(scope, n_loops = n_loops_float); + } + } + gpu!( scope, range(0u32, n_loops).for_each(|i, scope| { @@ -172,33 +163,72 @@ impl MatmulTiling2dPaddedShader { let k = scope.create_local(Elem::UInt); gpu!(scope, k = i * block_size_k); - // LHS - self.load_shared_memory( - scope, - k, - thread_col, - thread_row, - lhs_stride_col, - lhs_stride_row, - lhs, - offset_lhs, - shared_lhs, - true, - ); - - // RHS - self.load_shared_memory( - scope, - k, - thread_row, - thread_col, - rhs_stride_row, - rhs_stride_col, - rhs, - offset_rhs, - shared_rhs, - false, - ); + match self.assumption { + Tiling2DAssumption::Round => { + // LHS + self.load_shared_memory( + scope, + k, + thread_col, + thread_row, + lhs_stride_col, + lhs_stride_row, + lhs, + offset_lhs, + shared_lhs, + true, + ); + + // RHS + self.load_shared_memory( + scope, + k, + thread_row, + thread_col, + rhs_stride_row, + rhs_stride_col, + rhs, + offset_rhs, + shared_rhs, + false, + ); + } + Tiling2DAssumption::None => { + // LHS + self.load_shared_memory_with_bound_check( + scope, + k, + dim_k, + thread_col, + thread_row, + lhs_stride_col, + lhs_stride_row, + dim_m, + row, + lhs, + offset_lhs, + shared_lhs, + true, + ); + + // RHS + self.load_shared_memory_with_bound_check( + scope, + k, + dim_k, + thread_row, + thread_col, + rhs_stride_row, + rhs_stride_col, + dim_n, + col, + rhs, + offset_rhs, + shared_rhs, + false, + ); + } + } scope.register(Synchronization::WorkgroupBarrier); @@ -216,6 +246,8 @@ impl MatmulTiling2dPaddedShader { scope, row, col, + dim_m, + dim_n, out_stride_row, out_stride_col, results, @@ -224,6 +256,127 @@ impl MatmulTiling2dPaddedShader { ); } + fn load_shared_memory_with_bound_check( + &self, + scope: &mut Scope, + k: Variable, + dim_k: Variable, + thread_idx_1: Variable, + thread_idx_2: Variable, + stride_1: Variable, + 
stride_2: Variable, + dim: Variable, + pos_in_dim: Variable, + input: Variable, + input_offset: Variable, + shared_memory: Variable, + is_lhs: bool, + ) { + // How close is the thread to the end of the matrix. + // If < 4 then it is an edge case + + let remain = scope.create_local(Elem::UInt); + gpu!(scope, remain = dim - pos_in_dim); + + let block_size_k: Variable = self.config.block_size_k.into(); + let block_size_n: Variable = self.config.block_size_n.into(); + let elem = input.item().elem(); + + gpu!( + scope, + range(0_u32, 4u32, self.unroll).for_each(|j, scope| { + let current = scope.create_local(Elem::UInt); + gpu!(scope, current = thread_idx_1 + j); + + let aligned_with_shared_memory = scope.create_local(Elem::Bool); + gpu!(scope, aligned_with_shared_memory = current < block_size_k); + + // To avoid overwriting following row in shared memory + gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ + + // Position in shared memory + let sm_position = scope.create_local(Elem::UInt); + if is_lhs { + gpu!(scope, sm_position = thread_idx_2 / 4u32); + gpu!(scope, sm_position *= block_size_k); + gpu!(scope, sm_position += current); + } else { + gpu!(scope, sm_position = current * block_size_n); + gpu!(scope, sm_position += thread_idx_2); + gpu!(scope, sm_position = sm_position / 4u32); + } + + // To pad with zeros if outside lhs + let within_input = scope.create_local(Elem::Bool); + let current_with_k = scope.create_local(Elem::UInt); + let remain_at_least_1 = scope.create_local(Elem::Bool); + let read_condition = scope.create_local(Elem::Bool); + gpu!(scope, current_with_k = current + k); + gpu!(scope, within_input = current_with_k < dim_k); + gpu!(scope, remain_at_least_1 = remain >= 1u32); + gpu!(scope, read_condition = within_input && remain_at_least_1); + + gpu!(scope, if(read_condition).then(|scope| { + let position_0 = scope.create_local(Elem::UInt); + gpu!(scope, position_0 = k + current); + gpu!(scope, position_0 *= stride_1); + let tmp = scope.create_local(Elem::UInt); + gpu!(scope, tmp = thread_idx_2 * stride_2); + gpu!(scope, position_0 += tmp); + gpu!(scope, position_0 += input_offset); + let position_1 = scope.create_local(Elem::UInt); + let position_2 = scope.create_local(Elem::UInt); + let position_3 = scope.create_local(Elem::UInt); + gpu!(scope, position_1 = position_0 + stride_2); + gpu!(scope, position_2 = position_1 + stride_2); + gpu!(scope, position_3 = position_2 + stride_2); + + let val_0 = scope.zero(elem); + let val_1 = scope.zero(elem); + let val_2 = scope.zero(elem); + let val_3 = scope.zero(elem); + + let remain_n = scope.create_local(Elem::Bool); + gpu!(scope, remain_n = remain >= 4u32); + gpu!(scope, if(remain_n).then(|scope|{ + gpu!(scope, val_0 = input[position_0]); + gpu!(scope, val_1 = input[position_1]); + gpu!(scope, val_2 = input[position_2]); + gpu!(scope, val_3 = input[position_3]); + }).else(|scope|{ + gpu!(scope, remain_n = remain == 3u32); + gpu!(scope, if(remain_n).then(|scope|{ + gpu!(scope, val_0 = input[position_0]); + gpu!(scope, val_1 = input[position_1]); + gpu!(scope, val_2 = input[position_2]); + }).else(|scope|{ + gpu!(scope, remain_n = remain == 2u32); + gpu!(scope, if(remain_n).then(|scope|{ + gpu!(scope, val_0 = input[position_0]); + gpu!(scope, val_1 = input[position_1]); + }).else(|scope|{ + gpu!(scope, remain_n = remain == 1u32); + gpu!(scope, if(remain_n).then(|scope|{ + gpu!(scope, val_0 = input[position_0]); + })); + })); + })); + })); + + let val_vec4 = scope.create_local(shared_memory.item()); + gpu!(scope, val_vec4 = 
vec4(val_0, val_1, val_2, val_3)); + gpu!(scope, shared_memory[sm_position] = val_vec4); + }).else(|scope|{ + let val_0 = scope.zero(elem); + let val_vec4 = scope.create_local(shared_memory.item()); + gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0)); + gpu!(scope, shared_memory[sm_position] = val_vec4); + })); + })); + }) + ); + } + fn load_shared_memory( &self, scope: &mut Scope, @@ -244,29 +397,28 @@ impl MatmulTiling2dPaddedShader { gpu!( scope, range(0_u32, 4u32, self.unroll).for_each(|j, scope| { - let current_col = scope.create_local(Elem::UInt); - gpu!(scope, current_col = thread_idx_1 + j); + let current = scope.create_local(Elem::UInt); + gpu!(scope, current = thread_idx_1 + j); let aligned_with_shared_memory = scope.create_local(Elem::Bool); - gpu!( - scope, - aligned_with_shared_memory = current_col < block_size_k - ); + gpu!(scope, aligned_with_shared_memory = current < block_size_k); + // To avoid overwriting following row in shared memory gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ + let sm_position = scope.create_local(Elem::UInt); if is_lhs { gpu!(scope, sm_position = thread_idx_2 / 4u32); gpu!(scope, sm_position *= block_size_k); - gpu!(scope, sm_position += current_col); + gpu!(scope, sm_position += current); } else { - gpu!(scope, sm_position = current_col * block_size_n); + gpu!(scope, sm_position = current * block_size_n); gpu!(scope, sm_position += thread_idx_2); gpu!(scope, sm_position = sm_position / 4u32); } let position_0 = scope.create_local(Elem::UInt); - gpu!(scope, position_0 = k + current_col); + gpu!(scope, position_0 = k + current); gpu!(scope, position_0 *= stride_1); let tmp = scope.create_local(Elem::UInt); gpu!(scope, tmp = thread_idx_2 * stride_2); @@ -377,14 +529,14 @@ impl MatmulTiling2dPaddedShader { scope: &mut Scope, row: Variable, col: Variable, + dim_m: Variable, + dim_n: Variable, out_stride_row: Variable, out_stride_col: Variable, results: Variable, offset_output: Variable, out: Variable, ) { - let elem = results.item().elem(); - gpu!( scope, range(0u32, self.config.tile_size_m as u32, self.unroll).for_each( @@ -393,31 +545,49 @@ impl MatmulTiling2dPaddedShader { scope, range(0u32, self.config.tile_size_n as u32, self.unroll).for_each( |res_idx_n, scope| { - let results_position = scope.create_local(Elem::UInt); - gpu!( - scope, - results_position = res_idx_m * self.config.tile_size_n - ); - gpu!(scope, results_position += res_idx_n); - - let result = scope.create_local(elem); - gpu!(scope, result = results[results_position]); - - let output_position = scope.create_local(Elem::UInt); - let output_position_tmp1 = scope.create_local(Elem::UInt); - let output_position_tmp2 = scope.create_local(Elem::UInt); - gpu!(scope, output_position_tmp1 = row + res_idx_m); - gpu!(scope, output_position_tmp1 *= out_stride_row); - gpu!(scope, output_position_tmp2 = col + res_idx_n); - gpu!(scope, output_position_tmp2 *= out_stride_col); - gpu!( - scope, - output_position = output_position_tmp1 + output_position_tmp2 - ); - gpu!(scope, output_position += offset_output); - - // gpu!(scope, out[output_position] = tmp); - gpu!(scope, out[output_position] = result); + let row_index = scope.create_local(Elem::UInt); + let col_index = scope.create_local(Elem::UInt); + gpu!(scope, row_index = row + res_idx_m); + gpu!(scope, col_index = col + res_idx_n); + + match self.assumption { + Tiling2DAssumption::Round => self.write_inner( + scope, + res_idx_m, + res_idx_n, + row_index, + col_index, + out_stride_row, + out_stride_col, + results, + 
offset_output,
+                                    out,
+                                ),
+                                Tiling2DAssumption::None => {
+                                    let within_output = scope.create_local(Elem::Bool);
+                                    let within_output_tmp = scope.create_local(Elem::Bool);
+                                    gpu!(scope, within_output = row_index < dim_m);
+                                    gpu!(scope, within_output_tmp = col_index < dim_n);
+                                    gpu!(
+                                        scope,
+                                        within_output = within_output && within_output_tmp
+                                    );
+                                    gpu!(scope, if(within_output).then(|scope|{
+                                        self.write_inner(
+                                            scope,
+                                            res_idx_m,
+                                            res_idx_n,
+                                            row_index,
+                                            col_index,
+                                            out_stride_row,
+                                            out_stride_col,
+                                            results,
+                                            offset_output,
+                                            out,
+                                        );
+                                    }));
+                                }
+                            }
                         }
                     )
                 );
@@ -425,116 +595,37 @@ impl MatmulTiling2dPaddedShader {
             )
         );
     }
-}
 
-impl<R: Runtime> DynamicKernelSource for MatmulTiling2dPaddedEagerKernel<R> {
-    fn source(&self) -> SourceTemplate {
-        let mut scope = gpu::Scope::root();
-        let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into());
-        let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into());
-        let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into());
+    fn write_inner(
+        &self,
+        scope: &mut Scope,
+        res_idx_m: Variable,
+        res_idx_n: Variable,
+        row_index: Variable,
+        col_index: Variable,
+        out_stride_row: Variable,
+        out_stride_col: Variable,
+        results: Variable,
+        offset_output: Variable,
+        out: Variable,
+    ) {
+        let elem = results.item().elem();
+        let results_position = scope.create_local(Elem::UInt);
+        gpu!(
+            scope,
+            results_position = res_idx_m * self.config.tile_size_n
+        );
+        gpu!(scope, results_position += res_idx_n);
 
-        scope.write_global_custom(out);
+        let result = scope.create_local(elem);
+        gpu!(scope, result = results[results_position]);
 
-        MatmulTiling2dPaddedShader {
-            variables: gpu::BinaryOperator { lhs, rhs, out },
-            config: self.config.clone(),
-            unroll: true,
-        }
-        .expand(&mut scope);
-
-        let lhs = InputInfo::Array {
-            item: gpu::Elem::Float.into(),
-            visibility: gpu::Visibility::Read,
-        };
-        let rhs = InputInfo::Array {
-            item: gpu::Elem::Float.into(),
-            visibility: gpu::Visibility::Read,
-        };
-        let out = OutputInfo::Array {
-            item: gpu::Elem::Float.into(),
-        };
-
-        let info = CompilationInfo {
-            inputs: vec![lhs, rhs],
-            outputs: vec![out],
-            scope,
-        };
-
-        let settings = CompilationSettings::default().workgroup_size(gpu::WorkgroupSize::new(
-            self.config.grid_x as u32,
-            self.config.grid_y as u32,
-            1,
-        ));
-        let shader = Compilation::new(info).compile(settings);
-        let shader = <R::Compiler as Compiler>::compile(shader);
-        SourceTemplate::new(shader.to_string())
-    }
+        let output_position = scope.create_local(Elem::UInt);
+        gpu!(scope, row_index *= out_stride_row);
+        gpu!(scope, col_index *= out_stride_col);
+        gpu!(scope, output_position = row_index + col_index);
+        gpu!(scope, output_position += offset_output);
 
-    fn id(&self) -> String {
-        format!(
-            "{:?}config={:?}",
-            core::any::TypeId::of::<R>(),
-            self.config,
-        )
+        gpu!(scope, out[output_position] = result);
     }
 }
-
-/// Matrix multiplication using tiling 2d algorithm with
-/// vec4 primitive on both lhs and rhs, with no padding needed
-pub fn matmul_tiling_2d_padded<R: Runtime, E: JitElement, const D: usize>(
-    lhs: JitTensor<R, E, D>,
-    rhs: JitTensor<R, E, D>,
-    out: JitTensor<R, E, D>,
-    config: Tiling2dConfig,
-) -> JitTensor<R, E, D> {
-    let kernel = MatmulTiling2dPaddedEagerKernel::<R>::new(config.clone());
-    let client = lhs.client.clone();
-
-    // A tensor may need to be padded, in which case it will implicitly become contiguous
-    // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim.
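-    // ("Swapped" is detected purely from strides: batch_swapped_with_row_col
-    // reports true when some batch dim has a smaller stride than the row or
-    // col dim, i.e. the batch axis is no longer the slowest-moving in memory.)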
- // If batches were swapped among themselves, or if the last two dims are transposed, the underlying - // kernel handles it without needing to turn it into contiguous. - let round_lhs = pad_round::(lhs, config.block_size_m, config.block_size_k); - let lhs = match round_lhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_lhs.into_tensor(), - }; - let round_rhs = pad_round::(rhs, config.block_size_k, config.block_size_n); - let rhs = match round_rhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_rhs.into_tensor(), - }; - - let rounded_output_shape = shape_out(&lhs, &rhs); - - let num_elems = rounded_output_shape.num_elements(); - let buffer = client.empty(num_elems * core::mem::size_of::()); - let rounded_output = JitTensor::new( - rhs.client.clone(), - rhs.device.clone(), - rounded_output_shape.clone(), - buffer, - ); - - Execution::start(kernel, client) - .inputs(&[ - EagerHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), - ]) - .outputs(&[EagerHandle::new( - &rounded_output.handle, - &rounded_output.strides, - &rounded_output.shape.dims, - )]) - .execute(WorkgroupLaunch::Custom(tiling2d_launch_options( - &rounded_output.shape, - config, - ))); - - crop(rounded_output, out) -} diff --git a/crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul copy.wgsl b/crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul copy.wgsl deleted file mode 100644 index 0b7a0f31be..0000000000 --- a/crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul copy.wgsl +++ /dev/null @@ -1,250 +0,0 @@ -@group(0) -@binding(0) -var input_0_global: array; - -@group(0) -@binding(1) -var input_1_global: array; - -@group(0) -@binding(2) -var output_0_global: array; - -@group(0) -@binding(3) -var info: array; - -var shared_memory_0: array, 512>; - -var shared_memory_1: array, 512>; - -const WORKGROUP_SIZE_X = 16u; -const WORKGROUP_SIZE_Y = 16u; -const WORKGROUP_SIZE_Z = 1u; - -@compute -@workgroup_size(16, 16, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_index) local_idx: u32, - @builtin(workgroup_id) workgroup_id: vec3, -) { - var a_0_0: array; - - let rank: u32 = info[0]; - let rank_2: u32 = rank * 2u; - var l_0_0: u32; - var l_0_1: u32; - var l_0_2: u32; - var l_0_3: u32; - var l_0_4: u32; - var l_0_5: u32; - var l_0_6: u32; - var l_0_7: u32; - var l_0_8: u32; - var l_0_9: u32; - var l_0_10: u32; - var l_0_11: u32; - var l_0_12: u32; - var l_0_13: u32; - var l_0_14: u32; - var l_0_15: u32; - var l_0_16: u32; - var l_0_17: u32; - var l_0_18: u32; - var l_0_19: u32; - var l_0_20: vec4; - var l_0_21: vec4; - var l_0_22: u32; - l_0_0 = rank - 1u; - l_0_1 = info[(0u * rank_2) + rank + l_0_0 + 1u]; - l_0_2 = info[(1u * rank_2) + rank + l_0_0 + 1u]; - l_0_3 = info[(1u * rank_2) + rank + rank + 1u]; - l_0_4 = info[(0u * rank_2) + l_0_0 + 1u]; - l_0_5 = info[(0u * rank_2) + rank + 1u]; - l_0_6 = info[(1u * rank_2) + l_0_0 + 1u]; - l_0_7 = info[(1u * rank_2) + rank + 1u]; - l_0_8 = info[(2u * rank_2) + l_0_0 + 1u]; - l_0_9 = info[(2u * rank_2) + rank + 1u]; - l_0_10 = u32(workgroup_id.x); - l_0_10 = l_0_10 * 64u; - l_0_11 = u32(workgroup_id.y); - l_0_11 = l_0_11 * 64u; - l_0_12 = local_idx / 16u; - l_0_12 = l_0_12 * 4u; - l_0_13 = local_idx % 16u; - l_0_13 = l_0_13 * 4u; - l_0_14 = l_0_10 + l_0_12; - l_0_15 = l_0_11 + l_0_13; 
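-    // l_0_14 / l_0_15: absolute row / col of this thread's 4x4 output tile,
-    // i.e. the workgroup offset (l_0_10 / l_0_11) plus the offset of the
-    // thread within the workgroup (l_0_12 / l_0_13).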
- l_0_16 = l_0_10 * l_0_4; - l_0_17 = l_0_11 * l_0_7; - l_0_18 = l_0_1 * l_0_3; - l_0_18 = l_0_18 * global_id.z; - l_0_19 = rank - 2u; - l_0_16 = u32(0u); - l_0_17 = u32(0u); - - for (var l_1_0: u32 = 0u; l_1_0 < l_0_19; l_1_0++) { - var l_1_1: u32; - var l_1_2: u32; - var l_1_3: u32; - var l_1_4: u32; - var l_1_5: u32; - var l_1_6: u32; - var l_1_7: u32; - var l_1_8: u32; - l_1_1 = info[(2u * rank_2) + l_1_0 + 1u]; - l_1_2 = l_0_18 * 1u; - l_1_2 = l_1_2 / l_1_1; - l_1_3 = info[(0u * rank_2) + l_1_0 + 1u]; - l_1_4 = info[(0u * rank_2) + rank + l_1_0 + 1u]; - l_1_5 = l_1_2 % l_1_4; - l_1_5 = l_1_5 * l_1_3; - l_0_16 = l_0_16 + l_1_5; - l_1_6 = info[(1u * rank_2) + l_1_0 + 1u]; - l_1_7 = info[(1u * rank_2) + rank + l_1_0 + 1u]; - l_1_8 = l_1_2 % l_1_7; - l_1_8 = l_1_8 * l_1_6; - l_0_17 = l_0_17 + l_1_8; - } - l_0_16 = l_0_16 / 1u; - l_0_17 = l_0_17 / 1u; - l_0_22 = l_0_2 / 32u; - - for (var l_1_0: u32 = 0u; l_1_0 < l_0_22; l_1_0++) { - var l_1_1: u32; - l_1_1 = l_1_0 * 32u; - - for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { - var l_2_1: u32; - var l_2_2: bool; - l_2_1 = l_0_13 + l_2_0; - l_2_2 = l_2_1 < 32u; - if l_2_2 { - var l_3_0: u32; - var l_3_1: u32; - var l_3_2: u32; - var l_3_3: u32; - var l_3_4: u32; - var l_3_5: u32; - var l_3_6: f32; - var l_3_7: f32; - var l_3_8: f32; - var l_3_9: f32; - var l_3_10: vec4; - l_3_0 = l_0_12 / 4u; - l_3_0 = l_3_0 * 32u; - l_3_0 = l_3_0 + l_2_1; - l_3_1 = l_1_1 + l_2_1; - l_3_1 = l_3_1 * l_0_5; - l_3_2 = l_0_12 * l_0_4; - l_3_1 = l_3_1 + l_3_2; - l_3_1 = l_3_1 + l_0_16; - l_3_3 = l_3_1 + l_0_4; - l_3_4 = l_3_3 + l_0_4; - l_3_5 = l_3_4 + l_0_4; - l_3_6 = f32(input_0_global[l_3_1]); - l_3_7 = f32(input_0_global[l_3_3]); - l_3_8 = f32(input_0_global[l_3_4]); - l_3_9 = f32(input_0_global[l_3_5]); - l_3_10 = vec4(l_3_6, l_3_7, l_3_8, l_3_9); - shared_memory_0[l_3_0] = vec4(l_3_10); - } else { - break; - } - } - - for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { - var l_2_1: u32; - var l_2_2: bool; - l_2_1 = l_0_12 + l_2_0; - l_2_2 = l_2_1 < 32u; - if l_2_2 { - var l_3_0: u32; - var l_3_1: u32; - var l_3_2: u32; - var l_3_3: u32; - var l_3_4: u32; - var l_3_5: u32; - var l_3_6: f32; - var l_3_7: f32; - var l_3_8: f32; - var l_3_9: f32; - var l_3_10: vec4; - l_3_0 = l_2_1 * 64u; - l_3_0 = l_3_0 + l_0_13; - l_3_0 = l_3_0 / 4u; - l_3_1 = l_1_1 + l_2_1; - l_3_1 = l_3_1 * l_0_6; - l_3_2 = l_0_13 * l_0_7; - l_3_1 = l_3_1 + l_3_2; - l_3_1 = l_3_1 + l_0_17; - l_3_3 = l_3_1 + l_0_7; - l_3_4 = l_3_3 + l_0_7; - l_3_5 = l_3_4 + l_0_7; - l_3_6 = f32(input_1_global[l_3_1]); - l_3_7 = f32(input_1_global[l_3_3]); - l_3_8 = f32(input_1_global[l_3_4]); - l_3_9 = f32(input_1_global[l_3_5]); - l_3_10 = vec4(l_3_6, l_3_7, l_3_8, l_3_9); - shared_memory_1[l_3_0] = vec4(l_3_10); - } else { - break; - } - } - workgroupBarrier(); - - for (var l_2_0: u32 = 0u; l_2_0 < 32u; l_2_0++) { - var l_2_1: u32; - var l_2_2: u32; - l_2_1 = l_0_12 / 4u; - l_2_1 = l_2_1 * 32u; - l_2_1 = l_2_1 + l_2_0; - l_0_20 = vec4(shared_memory_0[l_2_1]); - l_2_2 = l_2_0 * 64u; - l_2_2 = l_2_2 + l_0_13; - l_2_2 = l_2_2 / 4u; - l_0_21 = vec4(shared_memory_1[l_2_2]); - - for (var l_3_0: u32 = 0u; l_3_0 < 4u; l_3_0++) { - for (var l_4_0: u32 = 0u; l_4_0 < 4u; l_4_0++) { - var l_4_1: f32; - var l_4_2: f32; - var l_4_3: f32; - var l_4_4: u32; - var l_4_5: f32; - var l_4_6: f32; - l_4_1 = f32(l_0_20[l_3_0]); - l_4_2 = f32(l_0_21[l_4_0]); - l_4_3 = l_4_1 * l_4_2; - l_4_4 = l_3_0 * 4u; - l_4_4 = l_4_4 + l_4_0; - l_4_5 = f32(a_0_0[l_4_4]); - l_4_6 = l_4_5 + l_4_3; - a_0_0[l_4_4] = f32(l_4_6); - } - } - } - 
workgroupBarrier(); - } - - for (var l_1_0: u32 = 0u; l_1_0 < 4u; l_1_0++) { - for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { - var l_2_1: u32; - var l_2_2: f32; - var l_2_3: u32; - var l_2_4: u32; - var l_2_5: u32; - l_2_1 = l_1_0 * 4u; - l_2_1 = l_2_1 + l_2_0; - l_2_2 = f32(a_0_0[l_2_1]); - l_2_4 = l_0_14 + l_1_0; - l_2_4 = l_2_4 * l_0_8; - l_2_5 = l_0_15 + l_2_0; - l_2_5 = l_2_5 * l_0_9; - l_2_3 = l_2_4 + l_2_5; - l_2_3 = l_2_3 + l_0_18; - output_0_global[l_2_3] = f32(l_2_2); - } - } -} \ No newline at end of file diff --git a/crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul.wgsl b/crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul.wgsl deleted file mode 100644 index 1b4021f142..0000000000 --- a/crates/burn-jit/src/template/matmul/blocktiling_2d/generated_matmul.wgsl +++ /dev/null @@ -1,211 +0,0 @@ -@group(0) -@binding(0) -var input_0_global: array; - -@group(0) -@binding(1) -var input_1_global: array; - -@group(0) -@binding(2) -var output_0_global: array; - -@group(0) -@binding(3) -var info: array; - -var shared_memory_0: array, 512>; - -var shared_memory_1: array, 512>; - -const WORKGROUP_SIZE_X = 16u; -const WORKGROUP_SIZE_Y = 16u; -const WORKGROUP_SIZE_Z = 1u; - -@compute -@workgroup_size(16, 16, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_index) local_idx: u32, - @builtin(workgroup_id) workgroup_id: vec3, -) { - var results: array; - - let rank: u32 = info[0]; - let rank_2: u32 = rank * 2u; - var l_0_0: u32; - var M: u32; - var K: u32; - var N: u32; - var lhs_stride_row: u32; - var lhs_stride_col: u32; - var rhs_stride_row: u32; - var rhs_stride_col: u32; - var out_stride_row: u32; - var out_stride_col: u32; - var skip_row: u32; - var skip_col: u32; - var thread_row: u32; - var thread_col: u32; - var row: u32; - var col: u32; - var offset_lhs: u32; - var offset_rhs: u32; - var offset_output: u32; - var l_0_19: u32; - var register_m: vec4; - var register_n: vec4; - var n_loops: u32; - l_0_0 = rank - 1u; - M = info[(0u * rank_2) + rank + l_0_0 + 1u]; - K = info[(1u * rank_2) + rank + l_0_0 + 1u]; - N = info[(1u * rank_2) + rank + rank + 1u]; - lhs_stride_row = info[(0u * rank_2) + l_0_0 + 1u]; - lhs_stride_col = info[(0u * rank_2) + rank + 1u]; - rhs_stride_row = info[(1u * rank_2) + l_0_0 + 1u]; - rhs_stride_col = info[(1u * rank_2) + rank + 1u]; - out_stride_row = info[(2u * rank_2) + l_0_0 + 1u]; - out_stride_col = info[(2u * rank_2) + rank + 1u]; - skip_row = u32(workgroup_id.x) * 64u; - skip_col = u32(workgroup_id.y) * 64u; - thread_row = (local_idx / 16u)*4u; - thread_col = (local_idx % 16u)*4u; - row = skip_row + thread_row; - col = skip_col + thread_col; - offset_lhs = skip_row * lhs_stride_row; - offset_rhs = skip_col * rhs_stride_col; - offset_output = M * N; - offset_output = offset_output * global_id.z; - l_0_19 = rank - 2u; - - for (var l_1_0: u32 = 0u; l_1_0 < l_0_19; l_1_0++) { - var l_1_1: u32; - var l_1_2: u32; - var l_1_3: u32; - var l_1_4: u32; - var l_1_5: u32; - var l_1_6: u32; - var l_1_7: u32; - var l_1_8: u32; - l_1_1 = info[(0u * rank_2) + l_1_0 + 1u]; - l_1_2 = info[(1u * rank_2) + l_1_0 + 1u]; - l_1_3 = info[(2u * rank_2) + l_1_0 + 1u]; - l_1_4 = info[(0u * rank_2) + rank + l_1_0 + 1u]; - l_1_5 = info[(1u * rank_2) + rank + l_1_0 + 1u]; - l_1_6 = offset_output / l_1_3; - l_1_7 = l_1_6 % l_1_4; - l_1_7 = l_1_7 * l_1_1; - offset_lhs = offset_lhs + l_1_7; - l_1_8 = l_1_6 % l_1_5; - l_1_8 = l_1_8 * l_1_2; - offset_rhs = offset_rhs + l_1_8; - } - n_loops = K / 32u; - - for (var 
i: u32 = 0u; i < n_loops; i++) { - var k: u32; - k = i * 32u; - - for (var j: u32 = 0u; j < 4u; j++) { - var current_col: u32; - var l_2_2: bool; - current_col = thread_col + j; - if current_col < 32u{ - var l_3_0: u32; - var l_3_1: u32; - var l_3_2: u32; - var l_3_3: u32; - var l_3_4: u32; - var l_3_5: u32; - var l_3_6: f32; - var l_3_7: f32; - var l_3_8: f32; - var l_3_9: f32; - var l_3_10: vec4; - l_3_0 = (thread_row / 4u) * 32u + current_col; - lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row; - lhs_position1 = lhs_position0 + lhs_stride_row; - lhs_position2 = lhs_position1 + lhs_stride_row; - lhs_position3 = lhs_position2 + lhs_stride_row; - l_3_6 = f32(input_0_global[lhs_position0]); - l_3_7 = f32(input_0_global[lhs_position1]); - l_3_8 = f32(input_0_global[lhs_position2]); - l_3_9 = f32(input_0_global[lhs_position3]); - l_3_10 = vec4(l_3_6, l_3_7, l_3_8, l_3_9); - shared_memory_0[l_3_0] = vec4(l_3_10); - } else { - break; - } - } - - for (var i: u32 = 0u; i < 4u; i++) { - var current_row: u32; - current_row = thread_row + i; - if current_row < 32u { - var l_3_0: u32; - var rhs_position0: u32; - var l_3_2: u32; - var l_3_3: u32; - var l_3_4: u32; - var l_3_5: u32; - var l_3_6: f32; - var l_3_7: f32; - var l_3_8: f32; - var l_3_9: f32; - var l_3_10: vec4; - rhs_sm_position = (current_row * 64u + thread_col) / 4u; - rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col; - rhs_position1 = rhs_position0 + rhs_stride_col; - rhs_position2 = rhs_position1 + rhs_stride_col; - rhs_position3 = rhs_position2 + rhs_stride_col; - l_3_6 = f32(input_1_global[rhs_position0]); - l_3_7 = f32(input_1_global[rhs_position1]); - l_3_8 = f32(input_1_global[rhs_position2]); - l_3_9 = f32(input_1_global[rhs_position3]); - l_3_10 = vec4(l_3_6, l_3_7, l_3_8, l_3_9); - shared_memory_1[rhs_sm_position] = vec4(l_3_10); - } else { - break; - } - } - workgroupBarrier(); - - for (var dot_index: u32 = 0u; dot_index < 32u; dot_index++) { - var lhs_sm_position: u32; - var rhs_sm_position: u32; - lhs_sm_position = (thread_row / 4u)*32u+dot_index; - register_m = vec4(shared_memory_0[lhs_sm_position]); - rhs_sm_position = (dot_index * 64u + thread_col) / 4u; - register_n = vec4(shared_memory_1[rhs_sm_position]); - - for (var res_idx_m: u32 = 0u; res_idx_m < 4u; res_idx_m++) { - for (var res_idx_n: u32 = 0u; res_idx_n < 4u; res_idx_n++) { - var left: f32; - var right: f32; - var multiplied: f32; - var results_position: u32; - var old: f32; - left = f32(register_m[res_idx_m]); - right = f32(register_n[res_idx_n]); - multiplied = left * right; - results_position = res_idx_m * 4u + res_idx_n; - old = f32(results[results_position]); - results[results_position] = f32(old + multiplied); - } - } - } - workgroupBarrier(); - } - - for (var res_idx_m: u32 = 0u; res_idx_m < 4u; res_idx_m++) { - for (var res_idx_n: u32 = 0u; res_idx_n < 4u; res_idx_n++) { - var result_position: u32; - var result: f32; - var output_position: u32; - result_position = res_idx_m * 4u + res_idx_n; - result = f32(results[result_position]); - output_position = (row + res_idx_m) * out_stride_row + (col + res_idx_n) * out_stride_col + offset_output; - output_0_global[output_position] = f32(result); - } - } -} \ No newline at end of file diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 4268a6f573..d0034e4112 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -187,7 +187,6 @@ where } pub(crate) fn 
batch_swapped_with_row_col(&self) -> bool { - println!("{:?}", self.strides); for d in 0..D - 2 { let stride = self.strides[d]; if stride < self.strides[D - 2] || stride < self.strides[D - 1] { diff --git a/crates/burn-jit/src/tests/matmul.rs b/crates/burn-jit/src/tests/matmul.rs index 1ca5c5b0e3..7c85478057 100644 --- a/crates/burn-jit/src/tests/matmul.rs +++ b/crates/burn-jit/src/tests/matmul.rs @@ -677,8 +677,6 @@ mod tests { y_jit.into_primitive(), strategy, )); - println!("{}", z_reference); - println!("{}", z); z_reference.into_data().assert_approx_eq(&z.into_data(), 3); } diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index cb1d78e85f..cdfb472f1f 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -432,6 +432,10 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(op.out), }, + gpu::Operator::Ceil(op) => wgsl::Instruction::Ceil { + input: self.compile_variable(op.input), + out: self.compile_variable(op.out), + }, gpu::Operator::Log(op) => wgsl::Instruction::Log { input: self.compile_variable(op.input), out: self.compile_variable(op.out), diff --git a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs index 7235def106..85b38098cd 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs @@ -109,6 +109,10 @@ pub enum Instruction { input: Variable, out: Variable, }, + Ceil { + input: Variable, + out: Variable, + }, Erf { input: Variable, out: Variable, @@ -274,6 +278,9 @@ impl Display for Instruction { Instruction::Sqrt { input, out } => { f.write_fmt(format_args!("{out} = sqrt({input});\n")) } + Instruction::Ceil { input, out } => { + f.write_fmt(format_args!("{out} = ceil({input});\n")) + } Instruction::Log1p { input, out } => { f.write_fmt(format_args!("{out} = log({input} + 1.0);\n")) } diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index 29be374b8a..a6b4e519b2 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -145,7 +145,6 @@ where } let source = kernel.source().complete(); - println!("{}", source); let pipeline = self.compile_source(&source); self.pipelines.insert(kernel_id.clone(), pipeline.clone()); From 41238421f949729306d8d226d4b0e86d16fb6706 Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 14 Mar 2024 13:13:10 -0400 Subject: [PATCH 14/25] clippy --- crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs index ef04abe05a..539218baeb 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs @@ -256,6 +256,7 @@ impl MatmulTiling2dShader { ); } + #[allow(clippy::too_many_arguments)] fn load_shared_memory_with_bound_check( &self, scope: &mut Scope, @@ -377,6 +378,7 @@ impl MatmulTiling2dShader { ); } + #[allow(clippy::too_many_arguments)] fn load_shared_memory( &self, scope: &mut Scope, @@ -448,6 +450,7 @@ impl MatmulTiling2dShader { ); } + #[allow(clippy::too_many_arguments)] fn computation_loop( &self, scope: &mut Scope, @@ -524,6 +527,7 @@ impl MatmulTiling2dShader { ); } + #[allow(clippy::too_many_arguments)] fn write_to_output( &self, scope: &mut Scope, @@ -596,6 +600,7 @@ impl 
MatmulTiling2dShader { ); } + #[allow(clippy::too_many_arguments)] fn write_inner( &self, scope: &mut Scope, From ed18cfa63fa2013304c1d8d7b729e2120345228e Mon Sep 17 00:00:00 2001 From: louisfd Date: Tue, 19 Mar 2024 09:03:51 -0400 Subject: [PATCH 15/25] bounds check as bool --- crates/burn-jit/src/kernel/matmul/base.rs | 8 - crates/burn-jit/src/kernel/matmul/tiling2d.rs | 30 ++- .../src/kernel/matmul/tiling2d_shader.rs | 206 +++++++++--------- 3 files changed, 111 insertions(+), 133 deletions(-) diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 301968c493..2673aaaa20 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -8,14 +8,6 @@ use super::{ init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, matmul_tiling_2d_padded, }; -#[derive(Debug, Clone)] -pub(crate) enum Tiling2DAssumption { - // Input shapes are divisible by their corresponding block sizes - Round, - // Bounds must be checked - None, -} - #[derive(Debug, Clone)] /// Tiling 2D parameters pub struct Tiling2dConfig { diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index ed43fa0cf9..ddc57bded5 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -14,7 +14,7 @@ use std::marker::PhantomData; use super::{ padding::{crop, pad_round, PaddingOutput}, - shape_out, tiling2d_launch_options, MatmulTiling2dShader, Tiling2DAssumption, Tiling2dConfig, + shape_out, tiling2d_launch_options, MatmulTiling2dShader, Tiling2dConfig, }; #[derive(new, Debug)] @@ -25,7 +25,7 @@ struct MatmulTiling2d { #[derive(new, Debug)] struct MatmulTiling2dEagerKernel { config: Tiling2dConfig, - assumption: Tiling2DAssumption, + bounds_check_required: bool, _runtime: PhantomData, } @@ -41,7 +41,7 @@ impl DynamicKernelSource for MatmulTiling2dEagerKernel { MatmulTiling2dShader { variables: gpu::BinaryOperator { lhs, rhs, out }, config: self.config.clone(), - assumption: self.assumption.clone(), + bounds_check_required: self.bounds_check_required, unroll: false, } .expand(&mut scope); @@ -76,10 +76,10 @@ impl DynamicKernelSource for MatmulTiling2dEagerKernel { fn id(&self) -> String { format!( - "{:?}config={:?}assumption={:?}", + "{:?}config={:?}boundcheck={:?}", core::any::TypeId::of::(), self.config, - self.assumption + self.bounds_check_required ) } } @@ -92,9 +92,9 @@ pub fn matmul_tiling_2d( out: JitTensor, config: Tiling2dConfig, ) -> JitTensor { - let assumption = check_assumption(&lhs.shape, &rhs.shape, &config); + let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config); - let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), assumption); + let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), bounds_check_required); let client = lhs.client.clone(); let lhs = match lhs.batch_swapped_with_row_col() { @@ -126,7 +126,7 @@ pub fn matmul_tiling_2d_padded, config: Tiling2dConfig, ) -> JitTensor { - let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), Tiling2DAssumption::Round); + let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), false); let client = lhs.client.clone(); // A tensor may need to be padded, in which case it will implicitly become contiguous @@ -177,16 +177,12 @@ pub fn matmul_tiling_2d_padded( +fn check_bound_requirement( lhs_shape: &Shape, rhs_shape: &Shape, config: &Tiling2dConfig, -) -> Tiling2DAssumption { - let m_divisible = lhs_shape.dims[D 
- 2] % config.block_size_m == 0; - let k_divisible = lhs_shape.dims[D - 1] % config.block_size_k == 0; - let n_divisible = rhs_shape.dims[D - 1] % config.block_size_n == 0; - match m_divisible && k_divisible && n_divisible { - true => Tiling2DAssumption::Round, - false => Tiling2DAssumption::None, - } +) -> bool { + !(lhs_shape.dims[D - 2] % config.block_size_m == 0) + || !(lhs_shape.dims[D - 1] % config.block_size_k == 0) + || !(rhs_shape.dims[D - 1] % config.block_size_n == 0) } diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs index 539218baeb..febd91f8d2 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs @@ -1,11 +1,11 @@ use crate::gpu::{gpu, BinaryOperator, Elem, Item, Scope, Synchronization, Variable}; -use super::{Tiling2DAssumption, Tiling2dConfig}; +use super::Tiling2dConfig; pub(crate) struct MatmulTiling2dShader { pub variables: BinaryOperator, pub config: Tiling2dConfig, - pub assumption: Tiling2DAssumption, + pub bounds_check_required: bool, pub unroll: bool, } @@ -140,20 +140,17 @@ impl MatmulTiling2dShader { ); let n_loops = scope.create_local(Elem::UInt); - match self.assumption { - Tiling2DAssumption::Round => { - gpu!(scope, n_loops = dim_k / block_size_k); - } - Tiling2DAssumption::None => { - let dim_k_float = scope.create_local(elem); - let block_size_k_float = scope.create_local(elem); - let n_loops_float = scope.create_local(elem); - gpu!(scope, dim_k_float = dim_k); - gpu!(scope, block_size_k_float = block_size_k); - gpu!(scope, n_loops_float = dim_k_float / block_size_k_float); - gpu!(scope, n_loops_float = ceil(n_loops_float)); - gpu!(scope, n_loops = n_loops_float); - } + if self.bounds_check_required { + let dim_k_float = scope.create_local(elem); + let block_size_k_float = scope.create_local(elem); + let n_loops_float = scope.create_local(elem); + gpu!(scope, dim_k_float = dim_k); + gpu!(scope, block_size_k_float = block_size_k); + gpu!(scope, n_loops_float = dim_k_float / block_size_k_float); + gpu!(scope, n_loops_float = ceil(n_loops_float)); + gpu!(scope, n_loops = n_loops_float); + } else { + gpu!(scope, n_loops = dim_k / block_size_k); } gpu!( @@ -163,71 +160,68 @@ impl MatmulTiling2dShader { let k = scope.create_local(Elem::UInt); gpu!(scope, k = i * block_size_k); - match self.assumption { - Tiling2DAssumption::Round => { - // LHS - self.load_shared_memory( - scope, - k, - thread_col, - thread_row, - lhs_stride_col, - lhs_stride_row, - lhs, - offset_lhs, - shared_lhs, - true, - ); - - // RHS - self.load_shared_memory( - scope, - k, - thread_row, - thread_col, - rhs_stride_row, - rhs_stride_col, - rhs, - offset_rhs, - shared_rhs, - false, - ); - } - Tiling2DAssumption::None => { - // LHS - self.load_shared_memory_with_bound_check( - scope, - k, - dim_k, - thread_col, - thread_row, - lhs_stride_col, - lhs_stride_row, - dim_m, - row, - lhs, - offset_lhs, - shared_lhs, - true, - ); - - // RHS - self.load_shared_memory_with_bound_check( - scope, - k, - dim_k, - thread_row, - thread_col, - rhs_stride_row, - rhs_stride_col, - dim_n, - col, - rhs, - offset_rhs, - shared_rhs, - false, - ); - } + if self.bounds_check_required { + // LHS + self.load_shared_memory_with_bound_check( + scope, + k, + dim_k, + thread_col, + thread_row, + lhs_stride_col, + lhs_stride_row, + dim_m, + row, + lhs, + offset_lhs, + shared_lhs, + true, + ); + + // RHS + self.load_shared_memory_with_bound_check( + scope, + k, + dim_k, + thread_row, + 
thread_col, + rhs_stride_row, + rhs_stride_col, + dim_n, + col, + rhs, + offset_rhs, + shared_rhs, + false, + ); + } else { + // LHS + self.load_shared_memory( + scope, + k, + thread_col, + thread_row, + lhs_stride_col, + lhs_stride_row, + lhs, + offset_lhs, + shared_lhs, + true, + ); + + // RHS + self.load_shared_memory( + scope, + k, + thread_row, + thread_col, + rhs_stride_row, + rhs_stride_col, + rhs, + offset_rhs, + shared_rhs, + false, + ); } scope.register(Synchronization::WorkgroupBarrier); @@ -554,8 +548,28 @@ impl MatmulTiling2dShader { gpu!(scope, row_index = row + res_idx_m); gpu!(scope, col_index = col + res_idx_n); - match self.assumption { - Tiling2DAssumption::Round => self.write_inner( + if self.bounds_check_required { + let within_output = scope.create_local(Elem::Bool); + let within_output_tmp = scope.create_local(Elem::Bool); + gpu!(scope, within_output = row_index < dim_m); + gpu!(scope, within_output_tmp = col_index < dim_n); + gpu!(scope, within_output = within_output && within_output_tmp); + gpu!(scope, if(within_output).then(|scope|{ + self.write_inner( + scope, + res_idx_m, + res_idx_n, + row_index, + col_index, + out_stride_row, + out_stride_col, + results, + offset_output, + out, + ); + })); + } else { + self.write_inner( scope, res_idx_m, res_idx_n, @@ -566,31 +580,7 @@ impl MatmulTiling2dShader { results, offset_output, out, - ), - Tiling2DAssumption::None => { - let within_output = scope.create_local(Elem::Bool); - let within_output_tmp = scope.create_local(Elem::Bool); - gpu!(scope, within_output = row_index < dim_m); - gpu!(scope, within_output_tmp = col_index < dim_n); - gpu!( - scope, - within_output = within_output && within_output_tmp - ); - gpu!(scope, if(within_output).then(|scope|{ - self.write_inner( - scope, - res_idx_m, - res_idx_n, - row_index, - col_index, - out_stride_row, - out_stride_col, - results, - offset_output, - out, - ); - })); - } + ) } } ) From 651010ac297ac0b8fd5c47e7ada022f3ffadefcf Mon Sep 17 00:00:00 2001 From: louisfd Date: Tue, 19 Mar 2024 09:13:14 -0400 Subject: [PATCH 16/25] lhs rhs as enum --- .../src/kernel/matmul/tiling2d_shader.rs | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs index febd91f8d2..02aa511347 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs @@ -9,6 +9,11 @@ pub(crate) struct MatmulTiling2dShader { pub unroll: bool, } +enum InputIdentifier { + Lhs, + Rhs, +} + impl MatmulTiling2dShader { pub(crate) fn expand(self, scope: &mut Scope) { // Inputs @@ -161,8 +166,8 @@ impl MatmulTiling2dShader { gpu!(scope, k = i * block_size_k); if self.bounds_check_required { - // LHS self.load_shared_memory_with_bound_check( + InputIdentifier::Lhs, scope, k, dim_k, @@ -175,11 +180,10 @@ impl MatmulTiling2dShader { lhs, offset_lhs, shared_lhs, - true, ); - // RHS self.load_shared_memory_with_bound_check( + InputIdentifier::Rhs, scope, k, dim_k, @@ -192,11 +196,10 @@ impl MatmulTiling2dShader { rhs, offset_rhs, shared_rhs, - false, ); } else { - // LHS self.load_shared_memory( + InputIdentifier::Lhs, scope, k, thread_col, @@ -206,11 +209,10 @@ impl MatmulTiling2dShader { lhs, offset_lhs, shared_lhs, - true, ); - // RHS self.load_shared_memory( + InputIdentifier::Rhs, scope, k, thread_row, @@ -220,7 +222,6 @@ impl MatmulTiling2dShader { rhs, offset_rhs, shared_rhs, - false, ); } @@ -253,6 +254,7 @@ impl 
MatmulTiling2dShader { #[allow(clippy::too_many_arguments)] fn load_shared_memory_with_bound_check( &self, + input_identifier: InputIdentifier, scope: &mut Scope, k: Variable, dim_k: Variable, @@ -265,7 +267,6 @@ impl MatmulTiling2dShader { input: Variable, input_offset: Variable, shared_memory: Variable, - is_lhs: bool, ) { // How close is the thread to the end of the matrix. // If < 4 then it is an edge case @@ -291,14 +292,17 @@ impl MatmulTiling2dShader { // Position in shared memory let sm_position = scope.create_local(Elem::UInt); - if is_lhs { - gpu!(scope, sm_position = thread_idx_2 / 4u32); - gpu!(scope, sm_position *= block_size_k); - gpu!(scope, sm_position += current); - } else { - gpu!(scope, sm_position = current * block_size_n); - gpu!(scope, sm_position += thread_idx_2); - gpu!(scope, sm_position = sm_position / 4u32); + match input_identifier { + InputIdentifier::Lhs => { + gpu!(scope, sm_position = thread_idx_2 / 4u32); + gpu!(scope, sm_position *= block_size_k); + gpu!(scope, sm_position += current); + }, + InputIdentifier::Rhs => { + gpu!(scope, sm_position = current * block_size_n); + gpu!(scope, sm_position += thread_idx_2); + gpu!(scope, sm_position = sm_position / 4u32); + } } // To pad with zeros if outside lhs @@ -375,6 +379,7 @@ impl MatmulTiling2dShader { #[allow(clippy::too_many_arguments)] fn load_shared_memory( &self, + input_identifier: InputIdentifier, scope: &mut Scope, k: Variable, thread_idx_1: Variable, @@ -384,7 +389,6 @@ impl MatmulTiling2dShader { input: Variable, input_offset: Variable, shared_memory: Variable, - is_lhs: bool, ) { let block_size_k: Variable = self.config.block_size_k.into(); let block_size_n: Variable = self.config.block_size_n.into(); @@ -403,14 +407,17 @@ impl MatmulTiling2dShader { gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ let sm_position = scope.create_local(Elem::UInt); - if is_lhs { - gpu!(scope, sm_position = thread_idx_2 / 4u32); - gpu!(scope, sm_position *= block_size_k); - gpu!(scope, sm_position += current); - } else { - gpu!(scope, sm_position = current * block_size_n); - gpu!(scope, sm_position += thread_idx_2); - gpu!(scope, sm_position = sm_position / 4u32); + match input_identifier { + InputIdentifier::Lhs => { + gpu!(scope, sm_position = thread_idx_2 / 4u32); + gpu!(scope, sm_position *= block_size_k); + gpu!(scope, sm_position += current); + }, + InputIdentifier::Rhs => { + gpu!(scope, sm_position = current * block_size_n); + gpu!(scope, sm_position += thread_idx_2); + gpu!(scope, sm_position = sm_position / 4u32); + } } let position_0 = scope.create_local(Elem::UInt); From 5e8786b5c768cf8444cd27f6f13bbff8b457a947 Mon Sep 17 00:00:00 2001 From: louisfd Date: Tue, 19 Mar 2024 11:16:13 -0400 Subject: [PATCH 17/25] tiling 2d major refactor --- crates/burn-jit/src/kernel/matmul/mod.rs | 1 - crates/burn-jit/src/kernel/matmul/tiling2d.rs | 3 +- .../src/kernel/matmul/tiling2d_shader.rs | 633 ------------------ .../src/kernel/matmul/tiling2d_shader/base.rs | 68 ++ .../matmul/tiling2d_shader/computation.rs | 79 +++ .../tiling2d_shader/load_shared_memory.rs | 264 ++++++++ .../src/kernel/matmul/tiling2d_shader/mod.rs | 11 + .../tiling2d_shader/shader_information.rs | 181 +++++ .../matmul/tiling2d_shader/write_output.rs | 99 +++ 9 files changed, 703 insertions(+), 636 deletions(-) delete mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs 
create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs diff --git a/crates/burn-jit/src/kernel/matmul/mod.rs b/crates/burn-jit/src/kernel/matmul/mod.rs index ba4830b2e1..324827ea41 100644 --- a/crates/burn-jit/src/kernel/matmul/mod.rs +++ b/crates/burn-jit/src/kernel/matmul/mod.rs @@ -20,4 +20,3 @@ pub mod padding; mod padding; pub use tiling2d::*; -use tiling2d_shader::*; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index ddc57bded5..5f663e4f13 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -13,8 +13,7 @@ use crate::{ use std::marker::PhantomData; use super::{ - padding::{crop, pad_round, PaddingOutput}, - shape_out, tiling2d_launch_options, MatmulTiling2dShader, Tiling2dConfig, + padding::{crop, pad_round, PaddingOutput}, shape_out, tiling2d_launch_options, tiling2d_shader::MatmulTiling2dShader, Tiling2dConfig }; #[derive(new, Debug)] diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs deleted file mode 100644 index 02aa511347..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader.rs +++ /dev/null @@ -1,633 +0,0 @@ -use crate::gpu::{gpu, BinaryOperator, Elem, Item, Scope, Synchronization, Variable}; - -use super::Tiling2dConfig; - -pub(crate) struct MatmulTiling2dShader { - pub variables: BinaryOperator, - pub config: Tiling2dConfig, - pub bounds_check_required: bool, - pub unroll: bool, -} - -enum InputIdentifier { - Lhs, - Rhs, -} - -impl MatmulTiling2dShader { - pub(crate) fn expand(self, scope: &mut Scope) { - // Inputs - let lhs = self.variables.lhs; - let rhs = self.variables.rhs; - let out = self.variables.out; - - // Config variables - let block_size_m: Variable = self.config.block_size_m.into(); - let block_size_k: Variable = self.config.block_size_k.into(); - let block_size_n: Variable = self.config.block_size_n.into(); - let tile_size_m: Variable = self.config.tile_size_m.into(); - let tile_size_n: Variable = self.config.tile_size_n.into(); - let n_threads_per_row: Variable = - (((self.config.block_size_n - 1) / self.config.tile_size_n) + 1).into(); - let results_size = (self.config.tile_size_m * self.config.tile_size_n) as u32; - - // Shader info - let local_idx = Variable::LocalInvocationIndex; - let batch = Variable::GlobalInvocationIdZ; - - // Shapes - let rank = Variable::Rank; - let ultimate_dim = scope.create_local(Elem::UInt); - let penultimate_dim = scope.create_local(Elem::UInt); - gpu!(scope, ultimate_dim = rank - 1u32); - gpu!(scope, penultimate_dim = rank - 2u32); - let dim_m = scope.create_local(Elem::UInt); - let dim_k = scope.create_local(Elem::UInt); - let dim_n = scope.create_local(Elem::UInt); - gpu!(scope, dim_m = shape(lhs, penultimate_dim)); - gpu!(scope, dim_k = shape(lhs, ultimate_dim)); - gpu!(scope, dim_n = shape(rhs, ultimate_dim)); - - // Strides - let lhs_stride_row = scope.create_local(Elem::UInt); - let lhs_stride_col = scope.create_local(Elem::UInt); - let rhs_stride_row = scope.create_local(Elem::UInt); - let rhs_stride_col = scope.create_local(Elem::UInt); - let out_stride_row = scope.create_local(Elem::UInt); - let out_stride_col = 
scope.create_local(Elem::UInt); - gpu!(scope, lhs_stride_row = stride(lhs, penultimate_dim)); - gpu!(scope, lhs_stride_col = stride(lhs, ultimate_dim)); - gpu!(scope, rhs_stride_row = stride(rhs, penultimate_dim)); - gpu!(scope, rhs_stride_col = stride(rhs, ultimate_dim)); - gpu!(scope, out_stride_row = stride(out, penultimate_dim)); - gpu!(scope, out_stride_col = stride(out, ultimate_dim)); - - // Workgroup offset - let skip_row = scope.create_local(Elem::UInt); - let workgroup_id_x = Variable::WorkgroupIdX; - gpu!(scope, skip_row = workgroup_id_x); - gpu!(scope, skip_row *= block_size_m); - let skip_col = scope.create_local(Elem::UInt); - let workgroup_id_y = Variable::WorkgroupIdY; - gpu!(scope, skip_col = workgroup_id_y); - gpu!(scope, skip_col *= block_size_n); - - // Position of the first element of the thread, relative to the block - let thread_row = scope.create_local(Elem::UInt); - gpu!(scope, thread_row = local_idx / n_threads_per_row); - gpu!(scope, thread_row *= tile_size_m); - let thread_col = scope.create_local(Elem::UInt); - gpu!(scope, thread_col = local_idx % n_threads_per_row); - gpu!(scope, thread_col *= tile_size_n); - - // Position of the first element of the thread, in absolute (in one batch) - let row = scope.create_local(Elem::UInt); - let col = scope.create_local(Elem::UInt); - gpu!(scope, row = skip_row + thread_row); - gpu!(scope, col = skip_col + thread_col); - - // Calculate offset. - let offset_lhs = scope.create_local(Elem::UInt); - let offset_rhs = scope.create_local(Elem::UInt); - gpu!(scope, offset_lhs = skip_row * lhs_stride_row); - gpu!(scope, offset_rhs = skip_col * rhs_stride_col); - - // Batch offset for the output. - let offset_output = scope.create_local(Elem::UInt); - let batch_dims = scope.create_local(Elem::UInt); - gpu!(scope, offset_output = dim_m * dim_n); - gpu!(scope, offset_output = offset_output * batch); - - // Batch offset for the lhs & rhs matrices. 
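-        // For each batch dim b: recover the output's index along b from the
-        // flat batch offset, wrap it by the input's shape along b (a shape of
-        // 1 always wraps to index 0, which is what makes broadcasting work),
-        // then scale by the input's stride along b.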
- gpu!(scope, batch_dims = rank - 2u32); - gpu!( - scope, - range(0u32, batch_dims).for_each(|b, scope| { - let stride_lhs = scope.create_local(Elem::UInt); - let stride_rhs = scope.create_local(Elem::UInt); - let stride_output = scope.create_local(Elem::UInt); - let shape_lhs = scope.create_local(Elem::UInt); - let shape_rhs = scope.create_local(Elem::UInt); - - gpu!(scope, stride_lhs = stride(lhs, b)); - gpu!(scope, stride_rhs = stride(rhs, b)); - gpu!(scope, stride_output = stride(out, b)); - gpu!(scope, shape_lhs = shape(lhs, b)); - gpu!(scope, shape_rhs = shape(rhs, b)); - - let tmp = scope.create_local(Elem::UInt); - gpu!(scope, tmp = offset_output / stride_output); - let tmp_lhs = scope.create_local(Elem::UInt); - gpu!(scope, tmp_lhs = tmp % shape_lhs); - gpu!(scope, tmp_lhs = tmp_lhs * stride_lhs); - gpu!(scope, offset_lhs += tmp_lhs); - - let tmp_rhs = scope.create_local(Elem::UInt); - gpu!(scope, tmp_rhs = tmp % shape_rhs); - gpu!(scope, tmp_rhs = tmp_rhs * stride_rhs); - gpu!(scope, offset_rhs += tmp_rhs); - }) - ); - - let elem = lhs.item().elem(); - - // Registers used in the compute pass - let results = scope.create_local_array(elem, results_size); - let register_m = scope.create_local(Item::Vec4(elem)); - let register_n = scope.create_local(Item::Vec4(elem)); - let shared_lhs = scope.create_shared( - Item::Vec4(elem), - self.config.block_size_m as u32 * self.config.block_size_k as u32 / 4u32, - ); - let shared_rhs = scope.create_shared( - Item::Vec4(elem), - self.config.block_size_k as u32 * self.config.block_size_n as u32 / 4u32, - ); - - let n_loops = scope.create_local(Elem::UInt); - if self.bounds_check_required { - let dim_k_float = scope.create_local(elem); - let block_size_k_float = scope.create_local(elem); - let n_loops_float = scope.create_local(elem); - gpu!(scope, dim_k_float = dim_k); - gpu!(scope, block_size_k_float = block_size_k); - gpu!(scope, n_loops_float = dim_k_float / block_size_k_float); - gpu!(scope, n_loops_float = ceil(n_loops_float)); - gpu!(scope, n_loops = n_loops_float); - } else { - gpu!(scope, n_loops = dim_k / block_size_k); - } - - gpu!( - scope, - range(0u32, n_loops).for_each(|i, scope| { - // Equivalent of looping from 0 to K with steps block_size_k - let k = scope.create_local(Elem::UInt); - gpu!(scope, k = i * block_size_k); - - if self.bounds_check_required { - self.load_shared_memory_with_bound_check( - InputIdentifier::Lhs, - scope, - k, - dim_k, - thread_col, - thread_row, - lhs_stride_col, - lhs_stride_row, - dim_m, - row, - lhs, - offset_lhs, - shared_lhs, - ); - - self.load_shared_memory_with_bound_check( - InputIdentifier::Rhs, - scope, - k, - dim_k, - thread_row, - thread_col, - rhs_stride_row, - rhs_stride_col, - dim_n, - col, - rhs, - offset_rhs, - shared_rhs, - ); - } else { - self.load_shared_memory( - InputIdentifier::Lhs, - scope, - k, - thread_col, - thread_row, - lhs_stride_col, - lhs_stride_row, - lhs, - offset_lhs, - shared_lhs, - ); - - self.load_shared_memory( - InputIdentifier::Rhs, - scope, - k, - thread_row, - thread_col, - rhs_stride_row, - rhs_stride_col, - rhs, - offset_rhs, - shared_rhs, - ); - } - - scope.register(Synchronization::WorkgroupBarrier); - - self.computation_loop( - scope, thread_col, thread_row, shared_lhs, shared_rhs, register_m, register_n, - results, - ); - - scope.register(Synchronization::WorkgroupBarrier); - }) - ); - - // Phase 3: Write to output - self.write_to_output( - scope, - row, - col, - dim_m, - dim_n, - out_stride_row, - out_stride_col, - results, - offset_output, - out, - ); - } 
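
The bound-checked loader deleted below (and recreated in load_shared_memory.rs by this refactor) guards every 4-wide read against the edge of the matrix: remain = dim - pos_in_dim counts how many elements are left before the boundary, and any lane past it keeps a zero instead of reading out of bounds. A minimal CPU sketch of that tail handling, where load_vec4_padded is a hypothetical helper and not part of the crate:

    fn load_vec4_padded(input: &[f32], start: usize, stride: usize, remain: usize) -> [f32; 4] {
        // Mirror of the val_0..val_3 ladder: only the first `remain` lanes
        // (at most 4) are read from memory; the rest keep their zero padding.
        let mut lanes = [0.0f32; 4];
        for i in 0..remain.min(4) {
            lanes[i] = input[start + i * stride];
        }
        lanes
    }

    // e.g. remain == 2 performs two reads and yields [x0, x1, 0.0, 0.0].

The shader spells the same logic out as an explicit if/else ladder on remain, presumably because the gpu! IR is unrolled into straight-line code rather than a data-dependent loop over register lanes.
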
- - #[allow(clippy::too_many_arguments)] - fn load_shared_memory_with_bound_check( - &self, - input_identifier: InputIdentifier, - scope: &mut Scope, - k: Variable, - dim_k: Variable, - thread_idx_1: Variable, - thread_idx_2: Variable, - stride_1: Variable, - stride_2: Variable, - dim: Variable, - pos_in_dim: Variable, - input: Variable, - input_offset: Variable, - shared_memory: Variable, - ) { - // How close is the thread to the end of the matrix. - // If < 4 then it is an edge case - - let remain = scope.create_local(Elem::UInt); - gpu!(scope, remain = dim - pos_in_dim); - - let block_size_k: Variable = self.config.block_size_k.into(); - let block_size_n: Variable = self.config.block_size_n.into(); - let elem = input.item().elem(); - - gpu!( - scope, - range(0_u32, 4u32, self.unroll).for_each(|j, scope| { - let current = scope.create_local(Elem::UInt); - gpu!(scope, current = thread_idx_1 + j); - - let aligned_with_shared_memory = scope.create_local(Elem::Bool); - gpu!(scope, aligned_with_shared_memory = current < block_size_k); - - // To avoid overwriting following row in shared memory - gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ - - // Position in shared memory - let sm_position = scope.create_local(Elem::UInt); - match input_identifier { - InputIdentifier::Lhs => { - gpu!(scope, sm_position = thread_idx_2 / 4u32); - gpu!(scope, sm_position *= block_size_k); - gpu!(scope, sm_position += current); - }, - InputIdentifier::Rhs => { - gpu!(scope, sm_position = current * block_size_n); - gpu!(scope, sm_position += thread_idx_2); - gpu!(scope, sm_position = sm_position / 4u32); - } - } - - // To pad with zeros if outside lhs - let within_input = scope.create_local(Elem::Bool); - let current_with_k = scope.create_local(Elem::UInt); - let remain_at_least_1 = scope.create_local(Elem::Bool); - let read_condition = scope.create_local(Elem::Bool); - gpu!(scope, current_with_k = current + k); - gpu!(scope, within_input = current_with_k < dim_k); - gpu!(scope, remain_at_least_1 = remain >= 1u32); - gpu!(scope, read_condition = within_input && remain_at_least_1); - - gpu!(scope, if(read_condition).then(|scope| { - let position_0 = scope.create_local(Elem::UInt); - gpu!(scope, position_0 = k + current); - gpu!(scope, position_0 *= stride_1); - let tmp = scope.create_local(Elem::UInt); - gpu!(scope, tmp = thread_idx_2 * stride_2); - gpu!(scope, position_0 += tmp); - gpu!(scope, position_0 += input_offset); - let position_1 = scope.create_local(Elem::UInt); - let position_2 = scope.create_local(Elem::UInt); - let position_3 = scope.create_local(Elem::UInt); - gpu!(scope, position_1 = position_0 + stride_2); - gpu!(scope, position_2 = position_1 + stride_2); - gpu!(scope, position_3 = position_2 + stride_2); - - let val_0 = scope.zero(elem); - let val_1 = scope.zero(elem); - let val_2 = scope.zero(elem); - let val_3 = scope.zero(elem); - - let remain_n = scope.create_local(Elem::Bool); - gpu!(scope, remain_n = remain >= 4u32); - gpu!(scope, if(remain_n).then(|scope|{ - gpu!(scope, val_0 = input[position_0]); - gpu!(scope, val_1 = input[position_1]); - gpu!(scope, val_2 = input[position_2]); - gpu!(scope, val_3 = input[position_3]); - }).else(|scope|{ - gpu!(scope, remain_n = remain == 3u32); - gpu!(scope, if(remain_n).then(|scope|{ - gpu!(scope, val_0 = input[position_0]); - gpu!(scope, val_1 = input[position_1]); - gpu!(scope, val_2 = input[position_2]); - }).else(|scope|{ - gpu!(scope, remain_n = remain == 2u32); - gpu!(scope, if(remain_n).then(|scope|{ - gpu!(scope, val_0 = 
input[position_0]); - gpu!(scope, val_1 = input[position_1]); - }).else(|scope|{ - gpu!(scope, remain_n = remain == 1u32); - gpu!(scope, if(remain_n).then(|scope|{ - gpu!(scope, val_0 = input[position_0]); - })); - })); - })); - })); - - let val_vec4 = scope.create_local(shared_memory.item()); - gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); - gpu!(scope, shared_memory[sm_position] = val_vec4); - }).else(|scope|{ - let val_0 = scope.zero(elem); - let val_vec4 = scope.create_local(shared_memory.item()); - gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0)); - gpu!(scope, shared_memory[sm_position] = val_vec4); - })); - })); - }) - ); - } - - #[allow(clippy::too_many_arguments)] - fn load_shared_memory( - &self, - input_identifier: InputIdentifier, - scope: &mut Scope, - k: Variable, - thread_idx_1: Variable, - thread_idx_2: Variable, - stride_1: Variable, - stride_2: Variable, - input: Variable, - input_offset: Variable, - shared_memory: Variable, - ) { - let block_size_k: Variable = self.config.block_size_k.into(); - let block_size_n: Variable = self.config.block_size_n.into(); - let elem = input.item().elem(); - - gpu!( - scope, - range(0_u32, 4u32, self.unroll).for_each(|j, scope| { - let current = scope.create_local(Elem::UInt); - gpu!(scope, current = thread_idx_1 + j); - - let aligned_with_shared_memory = scope.create_local(Elem::Bool); - gpu!(scope, aligned_with_shared_memory = current < block_size_k); - - // To avoid overwriting following row in shared memory - gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ - - let sm_position = scope.create_local(Elem::UInt); - match input_identifier { - InputIdentifier::Lhs => { - gpu!(scope, sm_position = thread_idx_2 / 4u32); - gpu!(scope, sm_position *= block_size_k); - gpu!(scope, sm_position += current); - }, - InputIdentifier::Rhs => { - gpu!(scope, sm_position = current * block_size_n); - gpu!(scope, sm_position += thread_idx_2); - gpu!(scope, sm_position = sm_position / 4u32); - } - } - - let position_0 = scope.create_local(Elem::UInt); - gpu!(scope, position_0 = k + current); - gpu!(scope, position_0 *= stride_1); - let tmp = scope.create_local(Elem::UInt); - gpu!(scope, tmp = thread_idx_2 * stride_2); - gpu!(scope, position_0 += tmp); - gpu!(scope, position_0 += input_offset); - let position_1 = scope.create_local(Elem::UInt); - let position_2 = scope.create_local(Elem::UInt); - let position_3 = scope.create_local(Elem::UInt); - gpu!(scope, position_1 = position_0 + stride_2); - gpu!(scope, position_2 = position_1 + stride_2); - gpu!(scope, position_3 = position_2 + stride_2); - - let val_0 = scope.create_local(elem); - let val_1 = scope.create_local(elem); - let val_2 = scope.create_local(elem); - let val_3 = scope.create_local(elem); - gpu!(scope, val_0 = input[position_0]); - gpu!(scope, val_1 = input[position_1]); - gpu!(scope, val_2 = input[position_2]); - gpu!(scope, val_3 = input[position_3]); - - let val_vec4 = scope.create_local(shared_memory.item()); - gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); - gpu!(scope, shared_memory[sm_position] = val_vec4); - })); - }) - ); - } - - #[allow(clippy::too_many_arguments)] - fn computation_loop( - &self, - scope: &mut Scope, - thread_col: Variable, - thread_row: Variable, - shared_lhs: Variable, - shared_rhs: Variable, - register_m: Variable, - register_n: Variable, - results: Variable, - ) { - let block_size_k: Variable = self.config.block_size_k.into(); - let block_size_n: Variable = self.config.block_size_n.into(); - let elem = 
results.item().elem(); - - gpu!( - scope, - range(0u32, self.config.block_size_k as u32, self.unroll).for_each( - |dot_index, scope| { - // Load a subcolumn of values from lhs - let lhs_sm_position = scope.create_local(Elem::UInt); - gpu!(scope, lhs_sm_position = thread_row / 4u32); - gpu!(scope, lhs_sm_position *= block_size_k); - gpu!(scope, lhs_sm_position += dot_index); - gpu!(scope, register_m = shared_lhs[lhs_sm_position]); - - // Load a subrow of values from rhs - let rhs_sm_position = scope.create_local(Elem::UInt); - gpu!(scope, rhs_sm_position = dot_index * block_size_n); - gpu!(scope, rhs_sm_position += thread_col); - gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32); - gpu!(scope, register_n = shared_rhs[rhs_sm_position]); - - gpu!( - scope, - range(0u32, self.config.tile_size_m as u32, self.unroll).for_each( - |res_idx_m, scope| { - gpu!( - scope, - range(0u32, self.config.tile_size_n as u32, self.unroll) - .for_each(|res_idx_n, scope| { - let registered_m = scope.create_local(elem); - let registered_n = scope.create_local(elem); - gpu!(scope, registered_m = register_m[res_idx_m]); - gpu!(scope, registered_n = register_n[res_idx_n]); - - let multiplied = scope.create_local(elem); - gpu!(scope, multiplied = registered_m * registered_n); - - let results_position = scope.create_local(Elem::UInt); - gpu!( - scope, - results_position = - res_idx_m * self.config.tile_size_n - ); - gpu!(scope, results_position += res_idx_n); - - let results_before = scope.create_local(elem); - gpu!(scope, results_before = results[results_position]); - let results_after = scope.create_local(elem); - gpu!( - scope, - results_after = results_before + multiplied - ); - - gpu!(scope, results[results_position] = results_after); - }) - ); - } - ) - ); - } - ) - ); - } - - #[allow(clippy::too_many_arguments)] - fn write_to_output( - &self, - scope: &mut Scope, - row: Variable, - col: Variable, - dim_m: Variable, - dim_n: Variable, - out_stride_row: Variable, - out_stride_col: Variable, - results: Variable, - offset_output: Variable, - out: Variable, - ) { - gpu!( - scope, - range(0u32, self.config.tile_size_m as u32, self.unroll).for_each( - |res_idx_m, scope| { - gpu!( - scope, - range(0u32, self.config.tile_size_n as u32, self.unroll).for_each( - |res_idx_n, scope| { - let row_index = scope.create_local(Elem::UInt); - let col_index = scope.create_local(Elem::UInt); - gpu!(scope, row_index = row + res_idx_m); - gpu!(scope, col_index = col + res_idx_n); - - if self.bounds_check_required { - let within_output = scope.create_local(Elem::Bool); - let within_output_tmp = scope.create_local(Elem::Bool); - gpu!(scope, within_output = row_index < dim_m); - gpu!(scope, within_output_tmp = col_index < dim_n); - gpu!(scope, within_output = within_output && within_output_tmp); - gpu!(scope, if(within_output).then(|scope|{ - self.write_inner( - scope, - res_idx_m, - res_idx_n, - row_index, - col_index, - out_stride_row, - out_stride_col, - results, - offset_output, - out, - ); - })); - } else { - self.write_inner( - scope, - res_idx_m, - res_idx_n, - row_index, - col_index, - out_stride_row, - out_stride_col, - results, - offset_output, - out, - ) - } - } - ) - ); - } - ) - ); - } - - #[allow(clippy::too_many_arguments)] - fn write_inner( - &self, - scope: &mut Scope, - res_idx_m: Variable, - res_idx_n: Variable, - row_index: Variable, - col_index: Variable, - out_stride_row: Variable, - out_stride_col: Variable, - results: Variable, - offset_output: Variable, - out: Variable, - ) { - let elem = 
results.item().elem();
-        let results_position = scope.create_local(Elem::UInt);
-        gpu!(
-            scope,
-            results_position = res_idx_m * self.config.tile_size_n
-        );
-        gpu!(scope, results_position += res_idx_n);
-
-        let result = scope.create_local(elem);
-        gpu!(scope, result = results[results_position]);
-
-        let output_position = scope.create_local(Elem::UInt);
-        gpu!(scope, row_index *= out_stride_row);
-        gpu!(scope, col_index *= out_stride_col);
-        gpu!(scope, output_position = row_index + col_index);
-        gpu!(scope, output_position += offset_output);
-
-        gpu!(scope, out[output_position] = result);
-    }
-}
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs
new file mode 100644
index 0000000000..da94b887f1
--- /dev/null
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs
@@ -0,0 +1,68 @@
+use crate::gpu::{gpu, BinaryOperator, Scope, Synchronization, Variable};
+
+use crate::kernel::matmul::tiling2d_shader::{
+    computation_loop, gather_shader_information, load_shared_memory, write_to_output,
+};
+use crate::kernel::matmul::Tiling2dConfig;
+
+pub(crate) struct MatmulTiling2dShader {
+    pub variables: BinaryOperator,
+    pub config: Tiling2dConfig,
+    pub bounds_check_required: bool,
+    pub unroll: bool,
+}
+
+pub(crate) struct Tiling2dState {
+    pub n_loops: Variable,
+    pub k: Variable,
+    pub lhs: Variable,
+    pub rhs: Variable,
+    pub out: Variable,
+    pub offset_lhs: Variable,
+    pub offset_rhs: Variable,
+    pub offset_output: Variable,
+    pub row: Variable,
+    pub col: Variable,
+    pub dim_m: Variable,
+    pub dim_k: Variable,
+    pub dim_n: Variable,
+    pub thread_col: Variable,
+    pub thread_row: Variable,
+    pub shared_lhs: Variable,
+    pub shared_rhs: Variable,
+    pub register_m: Variable,
+    pub register_n: Variable,
+    pub results: Variable,
+    pub lhs_stride_col: Variable,
+    pub lhs_stride_row: Variable,
+    pub rhs_stride_col: Variable,
+    pub rhs_stride_row: Variable,
+    pub out_stride_row: Variable,
+    pub out_stride_col: Variable,
+}
+
+impl MatmulTiling2dShader {
+    pub(crate) fn expand(self, scope: &mut Scope) {
+        let shader_state = gather_shader_information(scope, &self);
+
+        let block_size_k: Variable = self.config.block_size_k.into();
+        gpu!(
+            scope,
+            range(0u32, shader_state.n_loops).for_each(|i, scope| {
+                // Equivalent of looping from 0 to K with steps block_size_k
+                let k = shader_state.k;
+                gpu!(scope, k = i * block_size_k);
+
+                load_shared_memory(scope, &self, &shader_state);
+
+                scope.register(Synchronization::WorkgroupBarrier);
+
+                computation_loop(scope, &self, &shader_state);
+
+                scope.register(Synchronization::WorkgroupBarrier);
+            })
+        );
+
+        write_to_output(scope, &self, &shader_state);
+    }
+}
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs
new file mode 100644
index 0000000000..7d48ef824b
--- /dev/null
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs
@@ -0,0 +1,79 @@
+use crate::gpu::{gpu, Elem, Scope, Variable};
+
+use super::{MatmulTiling2dShader, Tiling2dState};
+
+#[allow(clippy::too_many_arguments)]
+pub fn computation_loop(
+    scope: &mut Scope,
+    shader: &MatmulTiling2dShader,
+    shader_state: &Tiling2dState,
+) {
+    let thread_col = shader_state.thread_col;
+    let thread_row = shader_state.thread_row;
+    let shared_lhs = shader_state.shared_lhs;
+    let shared_rhs = shader_state.shared_rhs;
+    let register_m = shader_state.register_m;
+    let register_n = shader_state.register_n;
+    let results = shader_state.results;
+
+    let block_size_k: Variable = shader.config.block_size_k.into();
+    let block_size_n: Variable = shader.config.block_size_n.into();
+    let elem = results.item().elem();
+
+    gpu!(
+        scope,
+        range(0u32, shader.config.block_size_k as u32, shader.unroll).for_each(
+            |dot_index, scope| {
+                // Load a subcolumn of values from lhs
+                let lhs_sm_position = scope.create_local(Elem::UInt);
+                gpu!(scope, lhs_sm_position = thread_row / 4u32);
+                gpu!(scope, lhs_sm_position *= block_size_k);
+                gpu!(scope, lhs_sm_position += dot_index);
+                gpu!(scope, register_m = shared_lhs[lhs_sm_position]);
+
+                // Load a subrow of values from rhs
+                let rhs_sm_position = scope.create_local(Elem::UInt);
+                gpu!(scope, rhs_sm_position = dot_index * block_size_n);
+                gpu!(scope, rhs_sm_position += thread_col);
+                gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32);
+                gpu!(scope, register_n = shared_rhs[rhs_sm_position]);
+
+                gpu!(
+                    scope,
+                    range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
+                        |res_idx_m, scope| {
+                            gpu!(
+                                scope,
+                                range(0u32, shader.config.tile_size_n as u32, shader.unroll)
+                                    .for_each(|res_idx_n, scope| {
+                                        let registered_m = scope.create_local(elem);
+                                        let registered_n = scope.create_local(elem);
+                                        gpu!(scope, registered_m = register_m[res_idx_m]);
+                                        gpu!(scope, registered_n = register_n[res_idx_n]);
+
+                                        let multiplied = scope.create_local(elem);
+                                        gpu!(scope, multiplied = registered_m * registered_n);
+
+                                        let results_position = scope.create_local(Elem::UInt);
+                                        gpu!(
+                                            scope,
+                                            results_position =
+                                                res_idx_m * shader.config.tile_size_n
+                                        );
+                                        gpu!(scope, results_position += res_idx_n);
+
+                                        let results_before = scope.create_local(elem);
+                                        gpu!(scope, results_before = results[results_position]);
+                                        let results_after = scope.create_local(elem);
+                                        gpu!(scope, results_after = results_before + multiplied);
+
+                                        gpu!(scope, results[results_position] = results_after);
+                                    })
+                            );
+                        }
+                    )
+                );
+            }
+        )
+    );
+}
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs
new file mode 100644
index 0000000000..6bbc16b751
--- /dev/null
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs
@@ -0,0 +1,264 @@
+use crate::gpu::{gpu, Elem, Scope, Variable};
+
+use super::{MatmulTiling2dShader, Tiling2dState};
+
+enum InputIdentifier {
+    Lhs,
+    Rhs,
+}
+
+pub(crate) fn load_shared_memory(
+    scope: &mut Scope,
+    shader: &MatmulTiling2dShader,
+    shader_state: &Tiling2dState,
+) {
+    if shader.bounds_check_required {
+        load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Lhs);
+        load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Rhs);
+    } else {
+        load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Lhs);
+        load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Rhs);
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+fn load_shared_memory_with_bound_check(
+    scope: &mut Scope,
+    shader: &MatmulTiling2dShader,
+    shader_state: &Tiling2dState,
+    input_identifier: InputIdentifier,
+) {
+    let (
+        input,
+        input_offset,
+        shared_memory,
+        thread_idx_1,
+        thread_idx_2,
+        stride_1,
+        stride_2,
+        dim,
+        pos_in_dim,
+    ) = match input_identifier {
+        InputIdentifier::Lhs => (
+            shader_state.lhs,
+            shader_state.offset_lhs,
+            shader_state.shared_lhs,
+            shader_state.thread_col,
+            shader_state.thread_row,
+            shader_state.lhs_stride_col,
+            shader_state.lhs_stride_row,
+            shader_state.dim_m,
+            shader_state.row,
+        ),
+        InputIdentifier::Rhs => (
+            shader_state.rhs,
+            shader_state.offset_rhs,
+            shader_state.shared_rhs,
+            shader_state.thread_row,
+            shader_state.thread_col,
+            shader_state.rhs_stride_row,
+            shader_state.rhs_stride_col,
+            shader_state.dim_n,
+            shader_state.col,
+        ),
+    };
+    let k = shader_state.k;
+    let dim_k = shader_state.dim_k;
+
+    // How close the thread is to the end of the matrix.
+    // If < 4, it is an edge case.
+    let remain = scope.create_local(Elem::UInt);
+    gpu!(scope, remain = dim - pos_in_dim);
+
+    let block_size_k: Variable = shader.config.block_size_k.into();
+    let block_size_n: Variable = shader.config.block_size_n.into();
+    let elem = input.item().elem();
+
+    gpu!(
+        scope,
+        range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
+            let current = scope.create_local(Elem::UInt);
+            gpu!(scope, current = thread_idx_1 + j);
+
+            let aligned_with_shared_memory = scope.create_local(Elem::Bool);
+            gpu!(scope, aligned_with_shared_memory = current < block_size_k);
+
+            // To avoid overwriting the following row in shared memory
+            gpu!(scope, if(aligned_with_shared_memory).then(|scope|{
+
+                // Position in shared memory
+                let sm_position = scope.create_local(Elem::UInt);
+                match input_identifier {
+                    InputIdentifier::Lhs => {
+                        gpu!(scope, sm_position = thread_idx_2 / 4u32);
+                        gpu!(scope, sm_position *= block_size_k);
+                        gpu!(scope, sm_position += current);
+                    },
+                    InputIdentifier::Rhs => {
+                        gpu!(scope, sm_position = current * block_size_n);
+                        gpu!(scope, sm_position += thread_idx_2);
+                        gpu!(scope, sm_position = sm_position / 4u32);
+                    }
+                }
+
+                // To pad with zeros if outside the input
+                let within_input = scope.create_local(Elem::Bool);
+                let current_with_k = scope.create_local(Elem::UInt);
+                let remain_at_least_1 = scope.create_local(Elem::Bool);
+                let read_condition = scope.create_local(Elem::Bool);
+                gpu!(scope, current_with_k = current + k);
+                gpu!(scope, within_input = current_with_k < dim_k);
+                gpu!(scope, remain_at_least_1 = remain >= 1u32);
+                gpu!(scope, read_condition = within_input && remain_at_least_1);
+
+                gpu!(scope, if(read_condition).then(|scope| {
+                    let position_0 = scope.create_local(Elem::UInt);
+                    gpu!(scope, position_0 = k + current);
+                    gpu!(scope, position_0 *= stride_1);
+                    let tmp = scope.create_local(Elem::UInt);
+                    gpu!(scope, tmp = thread_idx_2 * stride_2);
+                    gpu!(scope, position_0 += tmp);
+                    gpu!(scope, position_0 += input_offset);
+                    let position_1 = scope.create_local(Elem::UInt);
+                    let position_2 = scope.create_local(Elem::UInt);
+                    let position_3 = scope.create_local(Elem::UInt);
+                    gpu!(scope, position_1 = position_0 + stride_2);
+                    gpu!(scope, position_2 = position_1 + stride_2);
+                    gpu!(scope, position_3 = position_2 + stride_2);
+
+                    let val_0 = scope.zero(elem);
+                    let val_1 = scope.zero(elem);
+                    let val_2 = scope.zero(elem);
+                    let val_3 = scope.zero(elem);
+
+                    let remain_n = scope.create_local(Elem::Bool);
+                    gpu!(scope, remain_n = remain >= 4u32);
+                    gpu!(scope, if(remain_n).then(|scope|{
+                        gpu!(scope, val_0 = input[position_0]);
+                        gpu!(scope, val_1 = input[position_1]);
+                        gpu!(scope, val_2 = input[position_2]);
+                        gpu!(scope, val_3 = input[position_3]);
+                    }).else(|scope|{
+                        gpu!(scope, remain_n = remain == 3u32);
+                        gpu!(scope, if(remain_n).then(|scope|{
+                            gpu!(scope, val_0 = input[position_0]);
+                            gpu!(scope, val_1 = input[position_1]);
+                            gpu!(scope, val_2 = input[position_2]);
+                        }).else(|scope|{
+                            gpu!(scope, remain_n = remain == 2u32);
+                            gpu!(scope, if(remain_n).then(|scope|{
+                                gpu!(scope, val_0 = input[position_0]);
+                                gpu!(scope, val_1 = input[position_1]);
+                            }).else(|scope|{
+                                gpu!(scope, remain_n = remain == 1u32);
+                                gpu!(scope, if(remain_n).then(|scope|{
+                                    gpu!(scope, val_0 = input[position_0]);
+                                }));
+                            }));
+                        }));
+                    }));
+
+                    let val_vec4 = scope.create_local(shared_memory.item());
+                    gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
+                    gpu!(scope, shared_memory[sm_position] = val_vec4);
+                }).else(|scope|{
+                    let val_0 = scope.zero(elem);
+                    let val_vec4 = scope.create_local(shared_memory.item());
+                    gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0));
+                    gpu!(scope, shared_memory[sm_position] = val_vec4);
+                }));
+            }));
+        })
+    );
+}
+
+#[allow(clippy::too_many_arguments)]
+fn load_shared_memory_no_bound_check(
+    scope: &mut Scope,
+    shader: &MatmulTiling2dShader,
+    shader_state: &Tiling2dState,
+    input_identifier: InputIdentifier,
+) {
+    let (input, input_offset, shared_memory, thread_idx_1, thread_idx_2, stride_1, stride_2) =
+        match input_identifier {
+            InputIdentifier::Lhs => (
+                shader_state.lhs,
+                shader_state.offset_lhs,
+                shader_state.shared_lhs,
+                shader_state.thread_col,
+                shader_state.thread_row,
+                shader_state.lhs_stride_col,
+                shader_state.lhs_stride_row,
+            ),
+            InputIdentifier::Rhs => (
+                shader_state.rhs,
+                shader_state.offset_rhs,
+                shader_state.shared_rhs,
+                shader_state.thread_row,
+                shader_state.thread_col,
+                shader_state.rhs_stride_row,
+                shader_state.rhs_stride_col,
+            ),
+        };
+    let k = shader_state.k;
+
+    let block_size_k: Variable = shader.config.block_size_k.into();
+    let block_size_n: Variable = shader.config.block_size_n.into();
+    let elem = input.item().elem();
+
+    gpu!(
+        scope,
+        range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
+            let current = scope.create_local(Elem::UInt);
+            gpu!(scope, current = thread_idx_1 + j);
+
+            let aligned_with_shared_memory = scope.create_local(Elem::Bool);
+            gpu!(scope, aligned_with_shared_memory = current < block_size_k);
+
+            // To avoid overwriting the following row in shared memory
+            gpu!(scope, if(aligned_with_shared_memory).then(|scope|{
+
+                let sm_position = scope.create_local(Elem::UInt);
+                match input_identifier {
+                    InputIdentifier::Lhs => {
+                        gpu!(scope, sm_position = thread_idx_2 / 4u32);
+                        gpu!(scope, sm_position *= block_size_k);
+                        gpu!(scope, sm_position += current);
+                    },
+                    InputIdentifier::Rhs => {
+                        gpu!(scope, sm_position = current * block_size_n);
+                        gpu!(scope, sm_position += thread_idx_2);
+                        gpu!(scope, sm_position = sm_position / 4u32);
+                    }
+                }
+
+                let position_0 = scope.create_local(Elem::UInt);
+                gpu!(scope, position_0 = k + current);
+                gpu!(scope, position_0 *= stride_1);
+                let tmp = scope.create_local(Elem::UInt);
+                gpu!(scope, tmp = thread_idx_2 * stride_2);
+                gpu!(scope, position_0 += tmp);
+                gpu!(scope, position_0 += input_offset);
+                let position_1 = scope.create_local(Elem::UInt);
+                let position_2 = scope.create_local(Elem::UInt);
+                let position_3 = scope.create_local(Elem::UInt);
+                gpu!(scope, position_1 = position_0 + stride_2);
+                gpu!(scope, position_2 = position_1 + stride_2);
+                gpu!(scope, position_3 = position_2 + stride_2);
+
+                let val_0 = scope.create_local(elem);
+                let val_1 = scope.create_local(elem);
+                let val_2 = scope.create_local(elem);
+                let val_3 = scope.create_local(elem);
+                gpu!(scope, val_0 = input[position_0]);
+                gpu!(scope, val_1 = input[position_1]);
+                gpu!(scope, val_2 = input[position_2]);
+                gpu!(scope, val_3 = input[position_3]);
+
+                let val_vec4 = scope.create_local(shared_memory.item());
+                gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
+                gpu!(scope, shared_memory[sm_position] = val_vec4);
+            }));
+        })
+    );
+}
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs
new file mode 100644
index 0000000000..3ed28903d7
--- /dev/null
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs
@@ -0,0 +1,11 @@
+mod base;
+mod computation;
+mod load_shared_memory;
+mod shader_information;
+mod write_output;
+
+pub(crate) use base::*;
+pub(crate) use computation::*;
+pub(crate) use load_shared_memory::*;
+pub(crate) use shader_information::*;
+pub(crate) use write_output::*;
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs
new file mode 100644
index 0000000000..bc22152d88
--- /dev/null
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs
@@ -0,0 +1,181 @@
+use crate::gpu::{gpu, Elem, Item, Scope, Variable};
+
+use super::{MatmulTiling2dShader, Tiling2dState};
+
+pub(crate) fn gather_shader_information(
+    scope: &mut Scope,
+    shader: &MatmulTiling2dShader,
+) -> Tiling2dState {
+    // Inputs
+    let lhs = shader.variables.lhs;
+    let rhs = shader.variables.rhs;
+    let out = shader.variables.out;
+
+    // Config variables
+    let block_size_m: Variable = shader.config.block_size_m.into();
+    let block_size_k: Variable = shader.config.block_size_k.into();
+    let block_size_n: Variable = shader.config.block_size_n.into();
+    let tile_size_m: Variable = shader.config.tile_size_m.into();
+    let tile_size_n: Variable = shader.config.tile_size_n.into();
+    let n_threads_per_row: Variable =
+        (((shader.config.block_size_n - 1) / shader.config.tile_size_n) + 1).into();
+    let results_size = (shader.config.tile_size_m * shader.config.tile_size_n) as u32;
+
+    // Shader info
+    let local_idx = Variable::LocalInvocationIndex;
+    let batch = Variable::GlobalInvocationIdZ;
+
+    // Shapes
+    let rank = Variable::Rank;
+    let ultimate_dim = scope.create_local(Elem::UInt);
+    let penultimate_dim = scope.create_local(Elem::UInt);
+    gpu!(scope, ultimate_dim = rank - 1u32);
+    gpu!(scope, penultimate_dim = rank - 2u32);
+    let dim_m = scope.create_local(Elem::UInt);
+    let dim_k = scope.create_local(Elem::UInt);
+    let dim_n = scope.create_local(Elem::UInt);
+    gpu!(scope, dim_m = shape(lhs, penultimate_dim));
+    gpu!(scope, dim_k = shape(lhs, ultimate_dim));
+    gpu!(scope, dim_n = shape(rhs, ultimate_dim));
+
+    // Strides
+    let lhs_stride_row = scope.create_local(Elem::UInt);
+    let lhs_stride_col = scope.create_local(Elem::UInt);
+    let rhs_stride_row = scope.create_local(Elem::UInt);
+    let rhs_stride_col = scope.create_local(Elem::UInt);
+    let out_stride_row = scope.create_local(Elem::UInt);
+    let out_stride_col = scope.create_local(Elem::UInt);
+    gpu!(scope, lhs_stride_row = stride(lhs, penultimate_dim));
+    gpu!(scope, lhs_stride_col = stride(lhs, ultimate_dim));
+    gpu!(scope, rhs_stride_row = stride(rhs, penultimate_dim));
+    gpu!(scope, rhs_stride_col = stride(rhs, ultimate_dim));
+    gpu!(scope, out_stride_row = stride(out, penultimate_dim));
+    gpu!(scope, out_stride_col = stride(out, ultimate_dim));
+
+    // Workgroup offset
+    let skip_row = scope.create_local(Elem::UInt);
+    let workgroup_id_x = Variable::WorkgroupIdX;
+    gpu!(scope, skip_row = workgroup_id_x);
+    gpu!(scope, skip_row *= block_size_m);
+    let skip_col = scope.create_local(Elem::UInt);
+    let workgroup_id_y = Variable::WorkgroupIdY;
+    gpu!(scope, skip_col = workgroup_id_y);
+    gpu!(scope, skip_col *= block_size_n);
+
+    // Position of the first element of the thread, relative to the block
+    let thread_row = scope.create_local(Elem::UInt);
+    gpu!(scope, thread_row = local_idx / n_threads_per_row);
+    gpu!(scope, thread_row *= tile_size_m);
+    let thread_col = scope.create_local(Elem::UInt);
+    gpu!(scope, thread_col = local_idx % n_threads_per_row);
+    gpu!(scope, thread_col *= tile_size_n);
+
+    // Position of the first element of the thread, in absolute terms (within one batch)
+    let row = scope.create_local(Elem::UInt);
+    let col = scope.create_local(Elem::UInt);
+    gpu!(scope, row = skip_row + thread_row);
+    gpu!(scope, col = skip_col + thread_col);
+
+    // Calculate offset.
+    let offset_lhs = scope.create_local(Elem::UInt);
+    let offset_rhs = scope.create_local(Elem::UInt);
+    gpu!(scope, offset_lhs = skip_row * lhs_stride_row);
+    gpu!(scope, offset_rhs = skip_col * rhs_stride_col);
+
+    // Batch offset for the output.
+    let offset_output = scope.create_local(Elem::UInt);
+    let batch_dims = scope.create_local(Elem::UInt);
+    gpu!(scope, offset_output = dim_m * dim_n);
+    gpu!(scope, offset_output = offset_output * batch);
+
+    // Batch offset for the lhs & rhs matrices.
+    gpu!(scope, batch_dims = rank - 2u32);
+    gpu!(
+        scope,
+        range(0u32, batch_dims).for_each(|b, scope| {
+            let stride_lhs = scope.create_local(Elem::UInt);
+            let stride_rhs = scope.create_local(Elem::UInt);
+            let stride_output = scope.create_local(Elem::UInt);
+            let shape_lhs = scope.create_local(Elem::UInt);
+            let shape_rhs = scope.create_local(Elem::UInt);
+
+            gpu!(scope, stride_lhs = stride(lhs, b));
+            gpu!(scope, stride_rhs = stride(rhs, b));
+            gpu!(scope, stride_output = stride(out, b));
+            gpu!(scope, shape_lhs = shape(lhs, b));
+            gpu!(scope, shape_rhs = shape(rhs, b));
+
+            let tmp = scope.create_local(Elem::UInt);
+            gpu!(scope, tmp = offset_output / stride_output);
+            let tmp_lhs = scope.create_local(Elem::UInt);
+            gpu!(scope, tmp_lhs = tmp % shape_lhs);
+            gpu!(scope, tmp_lhs = tmp_lhs * stride_lhs);
+            gpu!(scope, offset_lhs += tmp_lhs);
+
+            let tmp_rhs = scope.create_local(Elem::UInt);
+            gpu!(scope, tmp_rhs = tmp % shape_rhs);
+            gpu!(scope, tmp_rhs = tmp_rhs * stride_rhs);
+            gpu!(scope, offset_rhs += tmp_rhs);
+        })
+    );
+
+    let elem = lhs.item().elem();
+
+    // Registers used in the compute pass
+    let results = scope.create_local_array(elem, results_size);
+    let register_m = scope.create_local(Item::Vec4(elem));
+    let register_n = scope.create_local(Item::Vec4(elem));
+    let shared_lhs = scope.create_shared(
+        Item::Vec4(elem),
+        shader.config.block_size_m as u32 * shader.config.block_size_k as u32 / 4u32,
+    );
+    let shared_rhs = scope.create_shared(
+        Item::Vec4(elem),
+        shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32,
+    );
+
+    let n_loops = scope.create_local(Elem::UInt);
+    if shader.bounds_check_required {
+        let dim_k_float = scope.create_local(elem);
+        let block_size_k_float = scope.create_local(elem);
+        let n_loops_float = scope.create_local(elem);
+        gpu!(scope, dim_k_float = dim_k);
+        gpu!(scope, block_size_k_float = block_size_k);
+        gpu!(scope, n_loops_float = dim_k_float / block_size_k_float);
+        gpu!(scope, n_loops_float = ceil(n_loops_float));
+        gpu!(scope, n_loops = n_loops_float);
+    } else {
+        gpu!(scope, n_loops = dim_k / block_size_k);
+    }
+
+    let k = scope.create_local(Elem::UInt);
+
+    Tiling2dState {
+        n_loops,
+        k,
+        lhs,
+        rhs,
+        out,
+        offset_lhs,
+        offset_rhs,
+        offset_output,
+        row,
+        col,
+        dim_m,
+        dim_k,
+        dim_n,
+        thread_col,
+        thread_row,
+        shared_lhs,
+        shared_rhs,
+        register_m,
+        register_n,
+        results,
+        lhs_stride_col,
+        lhs_stride_row,
+        rhs_stride_col,
+        rhs_stride_row,
+        out_stride_row,
+        out_stride_col,
+    }
+}
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs
new file mode 100644
index 0000000000..6884848579
--- /dev/null
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs
@@ -0,0 +1,99 @@
+use crate::gpu::{gpu, Elem, Scope, Variable};
+
+use super::{MatmulTiling2dShader, Tiling2dState};
+
+#[allow(clippy::too_many_arguments)]
+pub fn write_to_output(
+    scope: &mut Scope,
+    shader: &MatmulTiling2dShader,
+    shader_state: &Tiling2dState,
+) {
+    let row = shader_state.row;
+    let col = shader_state.col;
+    let dim_m = shader_state.dim_m;
+    let dim_n = shader_state.dim_n;
+
+    gpu!(
+        scope,
+        range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
+            |res_idx_m, scope| {
+                gpu!(
+                    scope,
+                    range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each(
+                        |res_idx_n, scope| {
+                            let row_index = scope.create_local(Elem::UInt);
+                            let col_index = scope.create_local(Elem::UInt);
+                            gpu!(scope, row_index = row + res_idx_m);
+                            gpu!(scope, col_index = col + res_idx_n);
+
+                            if shader.bounds_check_required {
+                                let within_output = scope.create_local(Elem::Bool);
+                                let within_output_tmp = scope.create_local(Elem::Bool);
+                                gpu!(scope, within_output = row_index < dim_m);
+                                gpu!(scope, within_output_tmp = col_index < dim_n);
+                                gpu!(scope, within_output = within_output && within_output_tmp);
+                                gpu!(scope, if(within_output).then(|scope|{
+                                    write_inner(
+                                        scope,
+                                        shader,
+                                        shader_state,
+                                        res_idx_m,
+                                        res_idx_n,
+                                        row_index,
+                                        col_index,
+                                    );
+                                }));
+                            } else {
+                                write_inner(
+                                    scope,
+                                    shader,
+                                    shader_state,
+                                    res_idx_m,
+                                    res_idx_n,
+                                    row_index,
+                                    col_index,
+                                )
+                            }
+                        }
+                    )
+                );
+            }
+        )
+    );
+}
+
+#[allow(clippy::too_many_arguments)]
+fn write_inner(
+    scope: &mut Scope,
+    shader: &MatmulTiling2dShader,
+    shader_state: &Tiling2dState,
+    res_idx_m: Variable,
+    res_idx_n: Variable,
+    row_index: Variable,
+    col_index: Variable,
+) {
+    let offset_output = shader_state.offset_output;
+    let out = shader_state.out;
+    let out_stride_row = shader_state.out_stride_row;
+    let out_stride_col = shader_state.out_stride_col;
+    let results = shader_state.results;
+
+    let elem = results.item().elem();
+    let results_position = scope.create_local(Elem::UInt);
+    gpu!(
+        scope,
+        results_position = res_idx_m * shader.config.tile_size_n
+    );
+    gpu!(scope, results_position += res_idx_n);
+
+    let result = scope.create_local(elem);
+    gpu!(scope, result = results[results_position]);
+
+    let output_position = scope.create_local(Elem::UInt);
+    gpu!(scope, row_index *= out_stride_row);
+    gpu!(scope, col_index *= out_stride_col);
+    gpu!(scope, output_position = row_index + col_index);
+    gpu!(scope, output_position += offset_output);
+
+    gpu!(scope, out[output_position] = result);
+}
From ea824144e43bbb805ce57b90478e3e965779ea47 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Tue, 19 Mar 2024 11:53:05 -0400
Subject: [PATCH 18/25] remove assign vec4
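
The dedicated `AssignVec4` operator and its WGSL instruction are removed;
vec4 construction is now expressed with the existing index-assignment
operators. Roughly, with hypothetical operand names, a call such as

    gpu!(scope, out = vec4(a, b, c, d));

now expands to the equivalent of

    let i = scope.zero(Elem::UInt);
    gpu!(scope, out[i] = a);
    gpu!(scope, i = i + 1u32);
    gpu!(scope, out[i] = b);
    gpu!(scope, i = i + 1u32);
    gpu!(scope, out[i] = c);
    gpu!(scope, i = i + 1u32);
    gpu!(scope, out[i] = d);

so one operator variant disappears from every backend.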
---
 crates/burn-jit/src/codegen/dialect/gpu/macros.rs | 11 ++++++++---
 .../src/codegen/dialect/gpu/operation.rs | 11 -----------
 .../src/codegen/dialect/gpu/vectorization.rs | 18 +-----------------
 crates/burn-jit/src/fusion/tracing/builder.rs | 7 -------
 .../src/kernel/matmul/tiling2d_shader/base.rs | 2 +-
 crates/burn-wgpu/src/compiler/wgsl/compiler.rs | 7 -------
 .../src/compiler/wgsl/instructions.rs | 10 ----------
 7 files changed, 10 insertions(+), 56 deletions(-)

diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs
index 99642cc1e2..25feb762a4 100644
--- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs
+++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs
@@ -295,9 +295,14 @@ macro_rules! gpu {
     };
     // out = vec4(a, b, c, d)
     ($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => {
-        $scope.register($crate::codegen::dialect::gpu::Operator::AssignVec4(
-            $crate::codegen::dialect::gpu::AssignVec4Operator{a:$a,b:$b,c:$c,d:$d,out:$out}
-        ));
+        let i = $scope.zero(Elem::UInt);
+        gpu!($scope, $out[i] = $a);
+        gpu!($scope, i = i + 1u32);
+        gpu!($scope, $out[i] = $b);
+        gpu!($scope, i = i + 1u32);
+        gpu!($scope, $out[i] = $c);
+        gpu!($scope, i = i + 1u32);
+        gpu!($scope, $out[i] = $d);
     };
     // out = input
     ($scope:expr, $out:ident = $input:ident) => {
diff --git a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs
index 5df1491b68..fc9a3315be 100644
--- a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs
+++ b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs
@@ -59,7 +59,6 @@ pub enum Operator {
     BitwiseXor(BinaryOperator),
     ShiftLeft(BinaryOperator),
     ShiftRight(BinaryOperator),
-    AssignVec4(AssignVec4Operator),
 }
 
 /// All metadata that can be accessed in a shader.
@@ -108,16 +107,6 @@ pub struct ClampOperator {
     pub out: Variable,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-#[allow(missing_docs)]
-pub struct AssignVec4Operator {
-    pub a: Variable,
-    pub b: Variable,
-    pub c: Variable,
-    pub d: Variable,
-    pub out: Variable,
-}
-
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[allow(missing_docs)]
 pub struct ReadGlobalOperator {
diff --git a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs
index ad4e00332c..9e095cd160 100644
--- a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs
+++ b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs
@@ -1,7 +1,4 @@
-use super::{
-    AssignVec4Operator, BinaryOperator, ClampOperator, Item, Operation, Operator, UnaryOperator,
-    Variable,
-};
+use super::{BinaryOperator, ClampOperator, Item, Operation, Operator, UnaryOperator, Variable};
 
 /// Define a vectorization scheme.
 #[allow(dead_code)]
@@ -83,7 +80,6 @@ impl Operator {
             Operator::BitwiseXor(op) => Operator::BitwiseXor(op.vectorize(vectorization)),
             Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)),
             Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)),
-            Operator::AssignVec4(op) => Operator::AssignVec4(op.vectorize(vectorization)),
         }
     }
 }
@@ -118,18 +114,6 @@ impl ClampOperator {
     }
 }
 
-impl AssignVec4Operator {
-    pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
-        Self {
-            a: self.a,
-            b: self.b,
-            c: self.c,
-            d: self.d,
-            out: self.out.vectorize(vectorization),
-        }
-    }
-}
-
 impl Variable {
     pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Self {
         match self {
diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs
index aaf922d3bf..16c9422e34 100644
--- a/crates/burn-jit/src/fusion/tracing/builder.rs
+++ b/crates/burn-jit/src/fusion/tracing/builder.rs
@@ -356,13 +356,6 @@ impl TraceBuilder {
                     &mut local_tensor_ids_input,
                     &mut local_tensor_ids_output,
                 ),
-                gpu::Operator::AssignVec4(op) => {
-                    mark(&op.a, &mut local_tensor_ids_input);
-                    mark(&op.b, &mut local_tensor_ids_input);
-                    mark(&op.c, &mut local_tensor_ids_input);
-                    mark(&op.d, &mut local_tensor_ids_input);
-                    mark(&op.out, &mut local_tensor_ids_output);
-                }
             },
             Operation::Procedure(proc) => {
                 match proc {
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs
index da94b887f1..9af7ed2391 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs
@@ -49,7 +49,7 @@ impl MatmulTiling2dShader {
         gpu!(
             scope,
             range(0u32, shader_state.n_loops).for_each(|i, scope| {
-                // Equivalent of looping from 0 to K with steps block_size_k
+                // From 0 to K in steps of block_size_k
                 let k = shader_state.k;
                 gpu!(scope, k = i * block_size_k);
 
diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs
index b8f10fa4f8..02dacf9d4c 100644
--- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs
+++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs
@@ -552,13 +552,6 @@ impl WgslCompiler {
                 rhs: self.compile_variable(op.rhs),
                 out: self.compile_variable(op.out),
             },
-            gpu::Operator::AssignVec4(op) => wgsl::Instruction::AssignVec4 {
-                a: self.compile_variable(op.a),
-                b: self.compile_variable(op.b),
-                c: self.compile_variable(op.c),
-                d: self.compile_variable(op.d),
-                out: self.compile_variable(op.out),
-            },
         }
     }
 
diff --git a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs
index a1e9d2cdbe..e0c6826534 100644
--- a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs
+++ b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs
@@ -214,13 +214,6 @@ pub enum Instruction {
         rhs: Variable,
         out: Variable,
     },
-    AssignVec4 {
-        a: Variable,
-        b: Variable,
-        c: Variable,
-        d: Variable,
-        out: Variable,
-    },
 }
 
 impl Display for Instruction {
@@ -475,9 +468,6 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
             Instruction::ShiftRight { lhs, rhs, out } => {
                 f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n"))
             }
-            Instruction::AssignVec4 { a, b, c, d, out } => {
-                f.write_fmt(format_args!("{out} = vec4({a}, {b}, {c}, {d});\n"))
-            }
         }
     }
 }
From 8b504f7dd042cd1e02cdedd5683669c3430062f2 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Tue, 19 Mar 2024 12:21:46 -0400
Subject: [PATCH 19/25] variable declarations above loops

---
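Note: this patch hoists `scope.create_local` calls out of the `gpu!` loop
bodies, so each intermediate is declared once per shader instead of being
re-registered on every (possibly unrolled) iteration. A minimal sketch of
the pattern, with hypothetical names:

    // Before: a fresh local is registered on every iteration.
    gpu!(
        scope,
        range(0u32, end, unroll).for_each(|i, scope| {
            let tmp = scope.create_local(Elem::UInt);
            gpu!(scope, tmp = i * stride);
        })
    );

    // After: declare once above the loop, only assign inside it.
    let tmp = scope.create_local(Elem::UInt);
    gpu!(
        scope,
        range(0u32, end, unroll).for_each(|i, scope| {
            gpu!(scope, tmp = i * stride);
        })
    );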
 .../matmul/tiling2d_shader/computation.rs | 19 +++--
 .../tiling2d_shader/load_shared_memory.rs | 71 ++++++++++--------
 .../tiling2d_shader/shader_information.rs | 53 +++++++------
 .../matmul/tiling2d_shader/write_output.rs | 74 ++++++++++++-------
 4 files changed, 124 insertions(+), 93 deletions(-)

diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs
index 7d48ef824b..47ca91e561 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs
@@ -20,19 +20,28 @@ pub fn computation_loop(
     let block_size_n: Variable = shader.config.block_size_n.into();
     let elem = results.item().elem();
 
+    let lhs_sm_position = scope.create_local(Elem::UInt);
+    let rhs_sm_position = scope.create_local(Elem::UInt);
+
+    let registered_m = scope.create_local(elem);
+    let registered_n = scope.create_local(elem);
+
+    let multiplied = scope.create_local(elem);
+    let results_position = scope.create_local(Elem::UInt);
+    let results_before = scope.create_local(elem);
+    let results_after = scope.create_local(elem);
+
     gpu!(
         scope,
         range(0u32, shader.config.block_size_k as u32, shader.unroll).for_each(
             |dot_index, scope| {
                 // Load a subcolumn of values from lhs
-                let lhs_sm_position = scope.create_local(Elem::UInt);
                 gpu!(scope, lhs_sm_position = thread_row / 4u32);
                 gpu!(scope, lhs_sm_position *= block_size_k);
                 gpu!(scope, lhs_sm_position += dot_index);
                 gpu!(scope, register_m = shared_lhs[lhs_sm_position]);
 
                 // Load a subrow of values from rhs
-                let rhs_sm_position = scope.create_local(Elem::UInt);
                 gpu!(scope, rhs_sm_position = dot_index * block_size_n);
                 gpu!(scope, rhs_sm_position += thread_col);
                 gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32);
@@ -46,15 +55,11 @@ pub fn computation_loop(
                     scope,
                     range(0u32, shader.config.tile_size_n as u32, shader.unroll)
                         .for_each(|res_idx_n, scope| {
-                            let registered_m = scope.create_local(elem);
-                            let registered_n = scope.create_local(elem);
                             gpu!(scope, registered_m = register_m[res_idx_m]);
                             gpu!(scope, registered_n = register_n[res_idx_n]);
 
-                            let multiplied = scope.create_local(elem);
                             gpu!(scope, multiplied = registered_m * registered_n);
 
-                            let results_position = scope.create_local(Elem::UInt);
                             gpu!(
                                 scope,
                                 results_position =
@@ -62,9 +67,7 @@ pub fn computation_loop(
                             );
                             gpu!(scope, results_position += res_idx_n);
 
-                            let results_before = scope.create_local(elem);
                             gpu!(scope, results_before = results[results_position]);
-                            let results_after = scope.create_local(elem);
                             gpu!(scope, results_after = results_before + multiplied);
 
                             gpu!(scope, results[results_position] = results_after);
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs
index 6bbc16b751..896b2f2424 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs
@@ -74,20 +74,26 @@ fn load_shared_memory_with_bound_check(
     let block_size_n: Variable = shader.config.block_size_n.into();
     let elem = input.item().elem();
 
+    let current = scope.create_local(Elem::UInt);
+    let aligned_with_shared_memory = scope.create_local(Elem::Bool);
+    let sm_position = scope.create_local(Elem::UInt);
+    let within_input = scope.create_local(Elem::Bool);
+    let current_with_k = scope.create_local(Elem::UInt);
+    let remain_at_least_1 = scope.create_local(Elem::Bool);
+    let read_condition = scope.create_local(Elem::Bool);
+    let val_vec4 = scope.create_local(shared_memory.item());
+
     gpu!(
         scope,
         range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
-            let current = scope.create_local(Elem::UInt);
             gpu!(scope, current = thread_idx_1 + j);
 
-            let aligned_with_shared_memory = scope.create_local(Elem::Bool);
             gpu!(scope, aligned_with_shared_memory = current < block_size_k);
 
             // To avoid overwriting the following row in shared memory
             gpu!(scope, if(aligned_with_shared_memory).then(|scope|{
 
                 // Position in shared memory
-                let sm_position = scope.create_local(Elem::UInt);
                 match input_identifier {
                     InputIdentifier::Lhs => {
                         gpu!(scope, sm_position = thread_idx_2 / 4u32);
@@ -102,53 +108,53 @@ fn load_shared_memory_with_bound_check(
                 }
 
                 // To pad with zeros if outside the input
-                let within_input = scope.create_local(Elem::Bool);
-                let current_with_k = scope.create_local(Elem::UInt);
-                let remain_at_least_1 = scope.create_local(Elem::Bool);
-                let read_condition = scope.create_local(Elem::Bool);
                 gpu!(scope, current_with_k = current + k);
                 gpu!(scope, within_input = current_with_k < dim_k);
                 gpu!(scope, remain_at_least_1 = remain >= 1u32);
                 gpu!(scope, read_condition = within_input && remain_at_least_1);
 
                 gpu!(scope, if(read_condition).then(|scope| {
-                    let position_0 = scope.create_local(Elem::UInt);
-                    gpu!(scope, position_0 = k + current);
-                    gpu!(scope, position_0 *= stride_1);
                     let tmp = scope.create_local(Elem::UInt);
-                    gpu!(scope, tmp = thread_idx_2 * stride_2);
-                    gpu!(scope, position_0 += tmp);
-                    gpu!(scope, position_0 += input_offset);
+                    let position_0 = scope.create_local(Elem::UInt);
                     let position_1 = scope.create_local(Elem::UInt);
                     let position_2 = scope.create_local(Elem::UInt);
                     let position_3 = scope.create_local(Elem::UInt);
-                    gpu!(scope, position_1 = position_0 + stride_2);
-                    gpu!(scope, position_2 = position_1 + stride_2);
-                    gpu!(scope, position_3 = position_2 + stride_2);
+                    let remain_n = scope.create_local(Elem::Bool);
 
                     let val_0 = scope.zero(elem);
                     let val_1 = scope.zero(elem);
                     let val_2 = scope.zero(elem);
                     let val_3 = scope.zero(elem);
 
-                    let remain_n = scope.create_local(Elem::Bool);
+                    gpu!(scope, position_0 = k + current);
+                    gpu!(scope, position_0 *= stride_1);
+                    gpu!(scope, tmp = thread_idx_2 * stride_2);
+                    gpu!(scope, position_0 += tmp);
+                    gpu!(scope, position_0 += input_offset);
+                    gpu!(scope, position_1 = position_0 + stride_2);
+                    gpu!(scope, position_2 = position_1 + stride_2);
+                    gpu!(scope, position_3 = position_2 + stride_2);
+
                     gpu!(scope, remain_n = remain >= 4u32);
                     gpu!(scope, if(remain_n).then(|scope|{
                         gpu!(scope, val_0 = input[position_0]);
                         gpu!(scope, val_1 = input[position_1]);
                         gpu!(scope, val_2 = input[position_2]);
                         gpu!(scope, val_3 = input[position_3]);
+
                     }).else(|scope|{
                         gpu!(scope, remain_n = remain == 3u32);
                         gpu!(scope, if(remain_n).then(|scope|{
                             gpu!(scope, val_0 = input[position_0]);
                             gpu!(scope, val_1 = input[position_1]);
                             gpu!(scope, val_2 = input[position_2]);
+
                         }).else(|scope|{
                             gpu!(scope, remain_n = remain == 2u32);
                             gpu!(scope, if(remain_n).then(|scope|{
                                 gpu!(scope, val_0 = input[position_0]);
                                 gpu!(scope, val_1 = input[position_1]);
+
                             }).else(|scope|{
                                 gpu!(scope, remain_n = remain == 1u32);
                                 gpu!(scope, if(remain_n).then(|scope|{
@@ -158,12 +164,11 @@ fn load_shared_memory_with_bound_check(
                             }));
                         }));
 
-                    let val_vec4 = scope.create_local(shared_memory.item());
                     gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
                     gpu!(scope, shared_memory[sm_position] = val_vec4);
+
                 }).else(|scope|{
                     let val_0 = scope.zero(elem);
-                    let val_vec4 = scope.create_local(shared_memory.item());
                     gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0));
                     gpu!(scope, shared_memory[sm_position] = val_vec4);
                 }));
@@ -206,19 +211,31 @@ fn load_shared_memory_no_bound_check(
     let block_size_n: Variable = shader.config.block_size_n.into();
     let elem = input.item().elem();
 
+    let current = scope.create_local(Elem::UInt);
+    let aligned_with_shared_memory = scope.create_local(Elem::Bool);
+    let sm_position = scope.create_local(Elem::UInt);
+
+    let tmp = scope.create_local(Elem::UInt);
+    let position_0 = scope.create_local(Elem::UInt);
+    let position_1 = scope.create_local(Elem::UInt);
+    let position_2 = scope.create_local(Elem::UInt);
+    let position_3 = scope.create_local(Elem::UInt);
+    let val_0 = scope.create_local(elem);
+    let val_1 = scope.create_local(elem);
+    let val_2 = scope.create_local(elem);
+    let val_3 = scope.create_local(elem);
+    let val_vec4 = scope.create_local(shared_memory.item());
+
     gpu!(
         scope,
         range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
-            let current = scope.create_local(Elem::UInt);
             gpu!(scope, current = thread_idx_1 + j);
 
-            let aligned_with_shared_memory = scope.create_local(Elem::Bool);
             gpu!(scope, aligned_with_shared_memory = current < block_size_k);
 
             // To avoid overwriting the following row in shared memory
             gpu!(scope, if(aligned_with_shared_memory).then(|scope|{
 
-                let sm_position = scope.create_local(Elem::UInt);
                 match input_identifier {
                     InputIdentifier::Lhs => {
                         gpu!(scope, sm_position = thread_idx_2 / 4u32);
@@ -232,30 +249,20 @@ fn load_shared_memory_no_bound_check(
                     }
                 }
 
-                let position_0 = scope.create_local(Elem::UInt);
                 gpu!(scope, position_0 = k + current);
                 gpu!(scope, position_0 *= stride_1);
-                let tmp = scope.create_local(Elem::UInt);
                 gpu!(scope, tmp = thread_idx_2 * stride_2);
                 gpu!(scope, position_0 += tmp);
                 gpu!(scope, position_0 += input_offset);
-                let position_1 = scope.create_local(Elem::UInt);
-                let position_2 = scope.create_local(Elem::UInt);
-                let position_3 = scope.create_local(Elem::UInt);
                 gpu!(scope, position_1 = position_0 + stride_2);
                 gpu!(scope, position_2 = position_1 + stride_2);
                 gpu!(scope, position_3 = position_2 + stride_2);
 
-                let val_0 = scope.create_local(elem);
-                let val_1 = scope.create_local(elem);
-                let val_2 = scope.create_local(elem);
-                let val_3 = scope.create_local(elem);
                 gpu!(scope, val_0 = input[position_0]);
                 gpu!(scope, val_1 = input[position_1]);
                 gpu!(scope, val_2 = input[position_2]);
                 gpu!(scope, val_3 = input[position_3]);
 
-                let val_vec4 = scope.create_local(shared_memory.item());
                 gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
                 gpu!(scope, shared_memory[sm_position] = val_vec4);
             }));
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs
index bc22152d88..fca13cebed 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs
@@ -27,16 +27,16 @@ pub(crate) fn gather_shader_information(
 
     // Shapes
     let rank = Variable::Rank;
-    let ultimate_dim = scope.create_local(Elem::UInt);
-    let penultimate_dim = scope.create_local(Elem::UInt);
-    gpu!(scope, ultimate_dim = rank - 1u32);
-    gpu!(scope, penultimate_dim = rank - 2u32);
+    let last_dim = scope.create_local(Elem::UInt);
+    let second_to_last_dim = scope.create_local(Elem::UInt);
     let dim_m = scope.create_local(Elem::UInt);
     let dim_k = scope.create_local(Elem::UInt);
     let dim_n = scope.create_local(Elem::UInt);
-    gpu!(scope, dim_m = shape(lhs, penultimate_dim));
-    gpu!(scope, dim_k = shape(lhs, ultimate_dim));
-    gpu!(scope, dim_n = shape(rhs, ultimate_dim));
+    gpu!(scope, last_dim = rank - 1u32);
+    gpu!(scope, second_to_last_dim = rank - 2u32);
+    gpu!(scope, dim_m = shape(lhs, second_to_last_dim));
+    gpu!(scope, dim_k = shape(lhs, last_dim));
+    gpu!(scope, dim_n = shape(rhs, last_dim));
 
     // Strides
     let lhs_stride_row = scope.create_local(Elem::UInt);
@@ -45,28 +45,28 @@ pub(crate) fn gather_shader_information(
     let rhs_stride_col = scope.create_local(Elem::UInt);
     let out_stride_row = scope.create_local(Elem::UInt);
     let out_stride_col = scope.create_local(Elem::UInt);
-    gpu!(scope, lhs_stride_row = stride(lhs, penultimate_dim));
-    gpu!(scope, lhs_stride_col = stride(lhs, ultimate_dim));
-    gpu!(scope, rhs_stride_row = stride(rhs, penultimate_dim));
-    gpu!(scope, rhs_stride_col = stride(rhs, ultimate_dim));
-    gpu!(scope, out_stride_row = stride(out, penultimate_dim));
-    gpu!(scope, out_stride_col = stride(out, ultimate_dim));
+    gpu!(scope, lhs_stride_row = stride(lhs, second_to_last_dim));
+    gpu!(scope, lhs_stride_col = stride(lhs, last_dim));
+    gpu!(scope, rhs_stride_row = stride(rhs, second_to_last_dim));
+    gpu!(scope, rhs_stride_col = stride(rhs, last_dim));
+    gpu!(scope, out_stride_row = stride(out, second_to_last_dim));
+    gpu!(scope, out_stride_col = stride(out, last_dim));
 
     // Workgroup offset
     let skip_row = scope.create_local(Elem::UInt);
+    let skip_col = scope.create_local(Elem::UInt);
     let workgroup_id_x = Variable::WorkgroupIdX;
+    let workgroup_id_y = Variable::WorkgroupIdY;
     gpu!(scope, skip_row = workgroup_id_x);
     gpu!(scope, skip_row *= block_size_m);
-    let skip_col = scope.create_local(Elem::UInt);
-    let workgroup_id_y = Variable::WorkgroupIdY;
     gpu!(scope, skip_col = workgroup_id_y);
     gpu!(scope, skip_col *= block_size_n);
 
     // Position of the first element of the thread, relative to the block
     let thread_row = scope.create_local(Elem::UInt);
+    let thread_col = scope.create_local(Elem::UInt);
     gpu!(scope, thread_row = local_idx / n_threads_per_row);
     gpu!(scope, thread_row *= tile_size_m);
-    let thread_col = scope.create_local(Elem::UInt);
     gpu!(scope, thread_col = local_idx % n_threads_per_row);
     gpu!(scope, thread_col *= tile_size_n);
 
@@ -89,30 +89,29 @@ pub(crate) fn gather_shader_information(
     gpu!(scope, offset_output = offset_output * batch);
 
     // Batch offset for the lhs & rhs matrices.
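+    // The loop below maps the output's flat batch offset back to offsets into
+    // lhs and rhs; `tmp % shape` accounts for broadcasting, since a batch dim
+    // of size 1 contributes nothing to that input's offset.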
+    let stride_lhs = scope.create_local(Elem::UInt);
+    let stride_rhs = scope.create_local(Elem::UInt);
+    let stride_output = scope.create_local(Elem::UInt);
+    let shape_lhs = scope.create_local(Elem::UInt);
+    let shape_rhs = scope.create_local(Elem::UInt);
+    let tmp = scope.create_local(Elem::UInt);
+    let tmp_lhs = scope.create_local(Elem::UInt);
+    let tmp_rhs = scope.create_local(Elem::UInt);
     gpu!(scope, batch_dims = rank - 2u32);
     gpu!(
         scope,
         range(0u32, batch_dims).for_each(|b, scope| {
-            let stride_lhs = scope.create_local(Elem::UInt);
-            let stride_rhs = scope.create_local(Elem::UInt);
-            let stride_output = scope.create_local(Elem::UInt);
-            let shape_lhs = scope.create_local(Elem::UInt);
-            let shape_rhs = scope.create_local(Elem::UInt);
-
             gpu!(scope, stride_lhs = stride(lhs, b));
             gpu!(scope, stride_rhs = stride(rhs, b));
             gpu!(scope, stride_output = stride(out, b));
             gpu!(scope, shape_lhs = shape(lhs, b));
             gpu!(scope, shape_rhs = shape(rhs, b));
 
-            let tmp = scope.create_local(Elem::UInt);
             gpu!(scope, tmp = offset_output / stride_output);
-            let tmp_lhs = scope.create_local(Elem::UInt);
             gpu!(scope, tmp_lhs = tmp % shape_lhs);
             gpu!(scope, tmp_lhs = tmp_lhs * stride_lhs);
             gpu!(scope, offset_lhs += tmp_lhs);
 
-            let tmp_rhs = scope.create_local(Elem::UInt);
             gpu!(scope, tmp_rhs = tmp % shape_rhs);
             gpu!(scope, tmp_rhs = tmp_rhs * stride_rhs);
             gpu!(scope, offset_rhs += tmp_rhs);
         })
@@ -134,7 +133,9 @@ pub(crate) fn gather_shader_information(
         shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32,
     );
 
+    // Calculate exact number of loop iterations
     let n_loops = scope.create_local(Elem::UInt);
+    let k = scope.create_local(Elem::UInt);
     if shader.bounds_check_required {
         let dim_k_float = scope.create_local(elem);
         let block_size_k_float = scope.create_local(elem);
         let n_loops_float = scope.create_local(elem);
@@ -148,8 +149,6 @@ pub(crate) fn gather_shader_information(
         gpu!(scope, n_loops = dim_k / block_size_k);
     }
 
-    let k = scope.create_local(Elem::UInt);
-
     Tiling2dState {
         n_loops,
         k,
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs
index 6884848579..ea09a0c9cf 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs
@@ -10,28 +10,32 @@ pub fn write_to_output(
 ) {
     let row = shader_state.row;
     let col = shader_state.col;
-    let dim_m = shader_state.dim_m;
-    let dim_n = shader_state.dim_n;
 
-    gpu!(
-        scope,
-        range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
-            |res_idx_m, scope| {
-                gpu!(
-                    scope,
-                    range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each(
-                        |res_idx_n, scope| {
-                            let row_index = scope.create_local(Elem::UInt);
-                            let col_index = scope.create_local(Elem::UInt);
-                            gpu!(scope, row_index = row + res_idx_m);
-                            gpu!(scope, col_index = col + res_idx_n);
+    let row_index = scope.create_local(Elem::UInt);
+    let col_index = scope.create_local(Elem::UInt);
+
+    if shader.bounds_check_required {
+        let dim_m = shader_state.dim_m;
+        let dim_n = shader_state.dim_n;
+
+        let within_output = scope.create_local(Elem::Bool);
+        let within_output_tmp = scope.create_local(Elem::Bool);
+
+        gpu!(
+            scope,
+            range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
+                |res_idx_m, scope| {
+                    gpu!(
+                        scope,
+                        range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each(
+                            |res_idx_n, scope| {
+                                gpu!(scope, row_index = row + res_idx_m);
+                                gpu!(scope, col_index = col + res_idx_n);
 
-                            if shader.bounds_check_required {
-                                let within_output = scope.create_local(Elem::Bool);
-                                let within_output_tmp = scope.create_local(Elem::Bool);
                                 gpu!(scope, within_output = row_index < dim_m);
                                 gpu!(scope, within_output_tmp = col_index < dim_n);
                                 gpu!(scope, within_output = within_output && within_output_tmp);
+
                                 gpu!(scope, if(within_output).then(|scope|{
                                     write_inner(
                                         scope,
                                         shader,
                                         shader_state,
                                         res_idx_m,
                                         res_idx_n,
                                         row_index,
                                         col_index,
                                     );
                                 }));
-                            } else {
+                            }
+                        )
+                    );
+                }
+            )
+        );
+    } else {
+        gpu!(
+            scope,
+            range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
+                |res_idx_m, scope| {
+                    gpu!(
+                        scope,
+                        range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each(
+                            |res_idx_n, scope| {
+                                gpu!(scope, row_index = row + res_idx_m);
+                                gpu!(scope, col_index = col + res_idx_n);
+
                                 write_inner(
                                     scope,
                                     shader,
                                     shader_state,
                                     res_idx_m,
                                     res_idx_n,
                                     row_index,
                                     col_index,
                                 )
                             }
-                        }
-                    )
-                );
-            }
-        )
-    );
+                        )
+                    );
+                }
+            )
+        );
+    }
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -80,16 +101,17 @@ fn write_inner(
 
     let elem = results.item().elem();
     let results_position = scope.create_local(Elem::UInt);
+    let result = scope.create_local(elem);
+    let output_position = scope.create_local(Elem::UInt);
+
     gpu!(
         scope,
         results_position = res_idx_m * shader.config.tile_size_n
     );
     gpu!(scope, results_position += res_idx_n);
 
-    let result = scope.create_local(elem);
     gpu!(scope, result = results[results_position]);
 
-    let output_position = scope.create_local(Elem::UInt);
     gpu!(scope, row_index *= out_stride_row);
     gpu!(scope, col_index *= out_stride_col);
     gpu!(scope, output_position = row_index + col_index);
From 947c47300c6d70b54a9e91258d4430cef514c1b1 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Tue, 19 Mar 2024 12:22:06 -0400
Subject: [PATCH 20/25] fmt

---
 crates/burn-jit/src/kernel/matmul/tiling2d.rs | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
index 5f663e4f13..58fa578b49 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
@@ -13,7 +13,10 @@ use crate::{
 use std::marker::PhantomData;
 
 use super::{
-    padding::{crop, pad_round, PaddingOutput}, shape_out, tiling2d_launch_options, tiling2d_shader::MatmulTiling2dShader, Tiling2dConfig
+    padding::{crop, pad_round, PaddingOutput},
+    shape_out, tiling2d_launch_options,
+    tiling2d_shader::MatmulTiling2dShader,
+    Tiling2dConfig,
 };
 
 #[derive(new, Debug)]
From 4c97c7566b427242fc53a4bd2bb542e149e475eb Mon Sep 17 00:00:00 2001
From: louisfd
Date: Tue, 19 Mar 2024 12:25:16 -0400
Subject: [PATCH 21/25] clippy

---
 crates/burn-jit/src/kernel/matmul/tiling2d.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
index 58fa578b49..052a88aa99 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
@@ -184,7 +184,7 @@ fn check_bound_requirement<const D: usize>(
     rhs_shape: &Shape<D>,
     config: &Tiling2dConfig,
 ) -> bool {
-    !(lhs_shape.dims[D - 2] % config.block_size_m == 0)
-        || !(lhs_shape.dims[D - 1] % config.block_size_k == 0)
-        || !(rhs_shape.dims[D - 1] % config.block_size_n == 0)
+    lhs_shape.dims[D - 2] % config.block_size_m != 0
+        || lhs_shape.dims[D - 1] % config.block_size_k != 0
+        || rhs_shape.dims[D - 1] % config.block_size_n != 0
 }
From 66076dbe0f721d84c368b210fa87a12b2aac8ee2 Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Wed, 20 Mar 2024 15:32:37 -0400
Subject: [PATCH 22/25] Fix autotune + unroll

---
 crates/burn-jit/src/kernel/matmul/base.rs     | 5 +----
 crates/burn-jit/src/kernel/matmul/tiling2d.rs | 2 +-
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs
index 2673aaaa20..aebba270a1 100644
--- a/crates/burn-jit/src/kernel/matmul/base.rs
+++ b/crates/burn-jit/src/kernel/matmul/base.rs
@@ -100,10 +100,7 @@ pub enum MatmulStrategy {
 #[cfg(feature = "autotune")]
 impl Default for MatmulStrategy {
     fn default() -> Self {
-        MatmulStrategy::Simple {
-            grid_x: 32,
-            grid_y: 32,
-        }
+        MatmulStrategy::Autotune
     }
 }
 
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
index 052a88aa99..d559097431 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
@@ -44,7 +44,7 @@ impl<R: Runtime> DynamicKernelSource for MatmulTiling2dEagerKernel<R> {
             variables: gpu::BinaryOperator { lhs, rhs, out },
             config: self.config.clone(),
             bounds_check_required: self.bounds_check_required,
-            unroll: false,
+            unroll: true,
         }
         .expand(&mut scope);
 
From 7a6299a65bee1d9c18011071780e87f945b58420 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Thu, 21 Mar 2024 08:50:30 -0400
Subject: [PATCH 23/25] move val

---
 .../tiling2d_shader/load_shared_memory.rs | 33 +++++++++++--------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs
index 896b2f2424..fd762d8f36 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs
@@ -83,6 +83,19 @@ fn load_shared_memory_with_bound_check(
     let read_condition = scope.create_local(Elem::Bool);
     let val_vec4 = scope.create_local(shared_memory.item());
 
+    let tmp = scope.create_local(Elem::UInt);
+    let position_0 = scope.create_local(Elem::UInt);
+    let position_1 = scope.create_local(Elem::UInt);
+    let position_2 = scope.create_local(Elem::UInt);
+    let position_3 = scope.create_local(Elem::UInt);
+    let remain_n = scope.create_local(Elem::Bool);
+
+    let val_0 = scope.create_local(elem);
+    let val_1 = scope.create_local(elem);
+    let val_2 = scope.create_local(elem);
+    let val_3 = scope.create_local(elem);
+    let zero: Variable = 0u32.into();
+
     gpu!(
         scope,
         range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
@@ -114,18 +127,6 @@ fn load_shared_memory_with_bound_check(
                 gpu!(scope, read_condition = within_input && remain_at_least_1);
 
                 gpu!(scope, if(read_condition).then(|scope| {
-                    let tmp = scope.create_local(Elem::UInt);
-                    let position_0 = scope.create_local(Elem::UInt);
-                    let position_1 = scope.create_local(Elem::UInt);
-                    let position_2 = scope.create_local(Elem::UInt);
-                    let position_3 = scope.create_local(Elem::UInt);
-                    let remain_n = scope.create_local(Elem::Bool);
-
-                    let val_0 = scope.zero(elem);
-                    let val_1 = scope.zero(elem);
-                    let val_2 = scope.zero(elem);
-                    let val_3 = scope.zero(elem);
-
                     gpu!(scope, position_0 = k + current);
                     gpu!(scope, position_0 *= stride_1);
                     gpu!(scope, tmp = thread_idx_2 * stride_2);
@@ -148,17 +149,23 @@ fn load_shared_memory_with_bound_check(
                             gpu!(scope, val_0 = input[position_0]);
                             gpu!(scope, val_1 = input[position_1]);
                             gpu!(scope, val_2 = input[position_2]);
+                            gpu!(scope, val_3 = zero);
 
                         }).else(|scope|{
                             gpu!(scope, remain_n = remain == 2u32);
                             gpu!(scope, if(remain_n).then(|scope|{
                                 gpu!(scope, val_0 = input[position_0]);
                                 gpu!(scope, val_1 = input[position_1]);
+                                gpu!(scope, val_2 = zero);
+                                gpu!(scope, val_3 = zero);
 
                             }).else(|scope|{
                                 gpu!(scope, remain_n = remain == 1u32);
                                 gpu!(scope, if(remain_n).then(|scope|{
                                     gpu!(scope, val_0 = input[position_0]);
+                                    gpu!(scope, val_1 = zero);
+                                    gpu!(scope, val_2 = zero);
+                                    gpu!(scope, val_3 = zero);
                                 }));
                             }));
@@ -168,7 +175,7 @@ fn load_shared_memory_with_bound_check(
                     gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
                     gpu!(scope, shared_memory[sm_position] = val_vec4);
 
                 }).else(|scope|{
-                    let val_0 = scope.zero(elem);
+                    gpu!(scope, val_0 = zero);
                     gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0));
                     gpu!(scope, shared_memory[sm_position] = val_vec4);
                 }));
From b3af9650271674cc51e7b8eb1e3868317132de3c Mon Sep 17 00:00:00 2001
From: louisfd
Date: Thu, 21 Mar 2024 09:01:10 -0400
Subject: [PATCH 24/25] clippy

---
 crates/burn-jit/src/kernel/matmul/base.rs | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs
index aebba270a1..cf2e4699df 100644
--- a/crates/burn-jit/src/kernel/matmul/base.rs
+++ b/crates/burn-jit/src/kernel/matmul/base.rs
@@ -80,6 +80,7 @@ impl Default for Tiling2dConfig {
 }
 
 /// The strategy to be used when launching a matmul kernel.
+#[derive(Default)]
 pub enum MatmulStrategy {
     /// A simple kernel will be used with memory coalescing optimization.
     Simple {
@@ -94,15 +95,12 @@ pub enum MatmulStrategy {
     Tiling2dPadded(Tiling2dConfig),
     #[cfg(feature = "autotune")]
     /// Using autotune to choose the best kernel based on runtime information.
+    #[default]
     Autotune,
 }
 
 #[cfg(feature = "autotune")]
-impl Default for MatmulStrategy {
-    fn default() -> Self {
-        MatmulStrategy::Autotune
-    }
-}
+
 
 #[cfg(not(feature = "autotune"))]
 impl Default for MatmulStrategy {
     fn default() -> Self {
From 6237478468c47bdfe08595b85dfd412996128557 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Thu, 21 Mar 2024 09:06:20 -0400
Subject: [PATCH 25/25] fmt

---
 crates/burn-jit/src/kernel/matmul/base.rs | 2 --
 1 file changed, 2 deletions(-)

diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs
index cf2e4699df..537d732dfb 100644
--- a/crates/burn-jit/src/kernel/matmul/base.rs
+++ b/crates/burn-jit/src/kernel/matmul/base.rs
@@ -100,8 +100,6 @@ pub enum MatmulStrategy {
 }
 
 #[cfg(feature = "autotune")]
-
-
 #[cfg(not(feature = "autotune"))]
 impl Default for MatmulStrategy {
     fn default() -> Self {