From dd699a90a2b8bf866ec0719ae040d6dec57eaead Mon Sep 17 00:00:00 2001 From: Louis Fortier-Dubois Date: Fri, 22 Mar 2024 08:26:32 -0400 Subject: [PATCH] Migrate/jit/matmul tiling 2d (#1472) * refactor matmul files * wip refactor matmul * everything is memco * support local arrays * advancing tiling2d * advancing tiling2d * advancing tiling2d * tiling2d finished but buggy * configurable unrolling * not bugged * fails on unroll * stupid break * tiling2d no assumption works * clippy * bounds check as bool * lhs rhs as enum * tiling 2d major refactor * remove assign vec4 * variable declarations above loops * fmt * clippy * Fix autotune + unroll * move val * clippy * fmt --------- Co-authored-by: nathaniel --- crates/burn-jit/src/codegen/compiler.rs | 2 + .../src/codegen/dialect/gpu/branch.rs | 12 + .../src/codegen/dialect/gpu/macros.rs | 21 +- .../src/codegen/dialect/gpu/operation.rs | 1 + .../burn-jit/src/codegen/dialect/gpu/scope.rs | 26 +- .../src/codegen/dialect/gpu/variable.rs | 3 + .../src/codegen/dialect/gpu/vectorization.rs | 7 + crates/burn-jit/src/fusion/tracing/builder.rs | 5 + crates/burn-jit/src/kernel/matmul/base.rs | 137 ++++++++- crates/burn-jit/src/kernel/matmul/mod.rs | 15 +- .../kernel/matmul/{tiling2d => }/padding.rs | 0 .../matmul/{mem_coalescing.rs => simple.rs} | 34 +-- crates/burn-jit/src/kernel/matmul/tiling2d.rs | 190 ++++++++++++ .../src/kernel/matmul/tiling2d/base.rs | 91 ------ .../src/kernel/matmul/tiling2d/mod.rs | 14 - .../src/kernel/matmul/tiling2d/unpadded.rs | 74 ----- .../src/kernel/matmul/tiling2d/vec4.rs | 49 --- .../src/kernel/matmul/tiling2d_shader/base.rs | 68 +++++ .../matmul/tiling2d_shader/computation.rs | 82 ++++++ .../tiling2d_shader/load_shared_memory.rs | 278 ++++++++++++++++++ .../src/kernel/matmul/tiling2d_shader/mod.rs | 11 + .../tiling2d_shader/shader_information.rs | 180 ++++++++++++ .../matmul/tiling2d_shader/write_output.rs | 121 ++++++++ .../burn-jit/src/kernel/matmul/tune/base.rs | 57 ++-- crates/burn-jit/src/tests/matmul.rs | 68 ++++- crates/burn-wgpu/src/compiler/wgsl/base.rs | 6 + .../burn-wgpu/src/compiler/wgsl/compiler.rs | 20 ++ .../src/compiler/wgsl/instructions.rs | 7 + crates/burn-wgpu/src/compiler/wgsl/shader.rs | 47 ++- 29 files changed, 1296 insertions(+), 330 deletions(-) rename crates/burn-jit/src/kernel/matmul/{tiling2d => }/padding.rs (100%) rename crates/burn-jit/src/kernel/matmul/{mem_coalescing.rs => simple.rs} (89%) create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d.rs delete mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d/base.rs delete mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d/mod.rs delete mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d/unpadded.rs delete mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d/vec4.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs create mode 100644 crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs diff --git a/crates/burn-jit/src/codegen/compiler.rs b/crates/burn-jit/src/codegen/compiler.rs index c06cd62be1..250fecfe92 100644 --- a/crates/burn-jit/src/codegen/compiler.rs +++ b/crates/burn-jit/src/codegen/compiler.rs @@ -22,4 +22,6 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { fn compile(shader: gpu::ComputeShader) -> Self::Representation; /// The size of the given element in bytes. fn elem_size(elem: gpu::Elem) -> usize; + /// The maximal size of a shared memory + fn max_shared_memory_size() -> usize; } diff --git a/crates/burn-jit/src/codegen/dialect/gpu/branch.rs b/crates/burn-jit/src/codegen/dialect/gpu/branch.rs index 1a617a8dd1..fdc3ea9ba4 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/branch.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/branch.rs @@ -119,3 +119,15 @@ impl Loop { parent_scope.register(Branch::Loop(op)); } } + +#[allow(missing_docs)] +pub struct UnrolledRangeLoop; + +impl UnrolledRangeLoop { + /// Registers an unrolled range loop to the given scope. + pub fn register(scope: &mut Scope, start: u32, end: u32, func: F) { + for i in start..end { + func(i.into(), scope); + } + } +} diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index 7f6dac0a96..25feb762a4 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -293,6 +293,17 @@ macro_rules! gpu { gpu!(unary $input, $out) )); }; + // out = vec4(a, b, c, d) + ($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => { + let i = $scope.zero(Elem::UInt); + gpu!($scope, $out[i] = $a); + gpu!($scope, i = i + 1u32); + gpu!($scope, $out[i] = $b); + gpu!($scope, i = i + 1u32); + gpu!($scope, $out[i] = $c); + gpu!($scope, i = i + 1u32); + gpu!($scope, $out[i] = $d); + }; // out = input ($scope:expr, $out:ident = $input:ident) => { gpu!($scope, $out = cast($input)) @@ -326,10 +337,18 @@ macro_rules! gpu { out: $out.into(), }); }; - // range(start, end).for_each(|scope| { ... }) + // range(start, end).for_each(|i, scope| { ... }) ($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => { $crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg); }; + // range(start, end, unroll).for_each(|i, scope| { ... }) + ($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => { + if $unroll { + $crate::codegen::dialect::gpu::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), $arg); + } else { + $crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg); + } + }; // loop(|scope| { ... }) ($scope:expr, loop($arg:expr)) => { $crate::codegen::dialect::gpu::Loop::register($scope, $arg); diff --git a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs index 43fe04d0b7..fc9a3315be 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs @@ -36,6 +36,7 @@ pub enum Operator { Tanh(UnaryOperator), Powf(BinaryOperator), Sqrt(UnaryOperator), + Ceil(UnaryOperator), Erf(UnaryOperator), Recip(UnaryOperator), Equal(BinaryOperator), diff --git a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs index dc70fafd09..585c59a601 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs @@ -20,7 +20,8 @@ pub struct Scope { pub depth: u8, pub operations: Vec, locals: Vec, - shared: Vec, + shared_memories: Vec, + local_arrays: Vec, reads_global: Vec<(Variable, ReadingStrategy, Variable)>, index_offset_with_output_layout_position: Vec, writes_global: Vec<(Variable, Variable)>, @@ -48,7 +49,8 @@ impl Scope { depth: 0, operations: Vec::new(), locals: Vec::new(), - shared: Vec::new(), + local_arrays: Vec::new(), + shared_memories: Vec::new(), reads_global: Vec::new(), index_offset_with_output_layout_position: Vec::new(), writes_global: Vec::new(), @@ -213,7 +215,8 @@ impl Scope { depth: self.depth + 1, operations: Vec::new(), locals: Vec::new(), - shared: Vec::new(), + shared_memories: Vec::new(), + local_arrays: Vec::new(), reads_global: Vec::new(), index_offset_with_output_layout_position: Vec::new(), writes_global: Vec::new(), @@ -308,7 +311,11 @@ impl Scope { } fn new_shared_index(&self) -> u16 { - self.shared.len() as u16 + self.shared_memories.len() as u16 + } + + fn new_local_array_index(&self) -> u16 { + self.local_arrays.len() as u16 } fn read_input_strategy( @@ -339,7 +346,16 @@ impl Scope { let item = item.into(); let index = self.new_shared_index(); let shared_memory = Variable::SharedMemory(index, item, shared_memory_size); - self.shared.push(shared_memory); + self.shared_memories.push(shared_memory); shared_memory } + + /// Create a local array of the given [item type](Item). + pub fn create_local_array>(&mut self, item: I, array_size: u32) -> Variable { + let item = item.into(); + let index = self.new_local_array_index(); + let local_array = Variable::LocalArray(index, item, self.depth, array_size); + self.local_arrays.push(local_array); + local_array + } } diff --git a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs index 7a49d5eea0..cd52ea4134 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs @@ -11,6 +11,7 @@ pub enum Variable { LocalScalar(u16, Elem, u8), ConstantScalar(f64, Elem), SharedMemory(u16, Item, u32), + LocalArray(u16, Item, u8, u32), Id, LocalInvocationIndex, LocalInvocationIdX, @@ -41,6 +42,7 @@ impl Variable { Variable::GlobalOutputArray(idx, _) => Some(*idx), Variable::ConstantScalar(_, _) => None, Variable::SharedMemory(idx, _, _) => Some(*idx), + Variable::LocalArray(idx, _, _, _) => Some(*idx), Variable::Id => None, Variable::LocalInvocationIndex => None, Variable::LocalInvocationIdX => None, @@ -70,6 +72,7 @@ impl Variable { Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem), Variable::ConstantScalar(_, elem) => Item::Scalar(*elem), Variable::SharedMemory(_, item, _) => *item, + Variable::LocalArray(_, item, _, _) => *item, Variable::Id => Item::Scalar(Elem::UInt), Variable::Rank => Item::Scalar(Elem::UInt), Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt), diff --git a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs index 9eb277a787..9e095cd160 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs @@ -52,6 +52,7 @@ impl Operator { Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)), Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)), Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)), + Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)), Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)), Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)), Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)), @@ -130,6 +131,12 @@ impl Variable { item.vectorize(vectorize), item.vectorized_size(vectorize, *size), ), + Variable::LocalArray(index, item, name, size) => Variable::LocalArray( + *index, + item.vectorize(vectorize), + *name, + item.vectorized_size(vectorize, *size), + ), Variable::ConstantScalar(_, _) => *self, Variable::GlobalScalar(_, _) => *self, Variable::Id => *self, diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index bdabf8f702..16c9422e34 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -247,6 +247,11 @@ impl TraceBuilder { &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), + gpu::Operator::Ceil(op) => mark_unary( + op, + &mut local_tensor_ids_input, + &mut local_tensor_ids_output, + ), gpu::Operator::Log(op) => mark_unary( op, &mut local_tensor_ids_input, diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 3eb9425878..537d732dfb 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,24 +1,98 @@ -use crate::{tensor::JitTensor, JitElement, Runtime}; +use std::cmp::{max, min}; + +use burn_tensor::Shape; + +use crate::{compute::WorkGroup, tensor::JitTensor, Compiler, JitElement, Runtime}; use super::{ - init_matmul_output, matmul_autotune, matmul_mem_coalescing, - unpadded::matmul_tiling_2d_unpadded, vec4::matmul_tiling_2d_vec4, + init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, matmul_tiling_2d_padded, }; +#[derive(Debug, Clone)] +/// Tiling 2D parameters +pub struct Tiling2dConfig { + /// Number of invocations in x + pub grid_x: usize, + /// Number of invocations in y + pub grid_y: usize, + /// Block size along dimension of lhs + pub block_size_m: usize, + /// Block size along common dimension + pub block_size_k: usize, + /// Block size along dimension of rhs + pub block_size_n: usize, + /// Tile size along dimension of lhs + pub tile_size_m: usize, + /// Tile size along dimension of rhs + pub tile_size_n: usize, +} + +impl Tiling2dConfig { + #[allow(unused)] + fn new( + grid_x: usize, + grid_y: usize, + block_size_m: usize, + block_size_k: usize, + block_size_n: usize, + tile_size_m: usize, + tile_size_n: usize, + ) -> Self { + assert!(grid_x == f32::ceil(block_size_m as f32 / tile_size_m as f32) as usize); + assert!(grid_y == f32::ceil(block_size_n as f32 / tile_size_n as f32) as usize); + assert!( + block_size_k <= min(block_size_m, block_size_n), + "Not enough invocations to fill shared memory" + ); + assert!( + block_size_k * max(block_size_m, block_size_n) + <= ::max_shared_memory_size(), + "Shared memory limit will be busted. " + ); + assert!( + block_size_m % tile_size_m == 0 && block_size_n % tile_size_n == 0, + "Tile size must divide block size in m and n dimensions" + ); + Self { + grid_x, + grid_y, + block_size_m, + block_size_k, + block_size_n, + tile_size_m, + tile_size_n, + } + } +} + +impl Default for Tiling2dConfig { + fn default() -> Self { + Self { + grid_x: 16, + grid_y: 16, + block_size_m: 64, + block_size_k: 32, + block_size_n: 64, + tile_size_m: 4, + tile_size_n: 4, + } + } +} + /// The strategy to be used when launching a matmul kernel. #[derive(Default)] pub enum MatmulStrategy { /// A simple kernel will be used with memory coalescing optimization. Simple { - /// Grad size x + /// Number of invocations in x grid_x: usize, - /// Grad size y + /// Number of invocations in y grid_y: usize, }, /// A tiling 2d kernel will be used, with support for any matrix size without padding. - Tiling2d, + Tiling2d(Tiling2dConfig), /// A tiling 2d kernel will be used, with support for any matrix size with padding. - Tiling2dPadded, + Tiling2dPadded(Tiling2dConfig), #[cfg(feature = "autotune")] /// Using autotune to chose the best kernel based on runtime information. #[default] @@ -42,17 +116,56 @@ pub fn matmul( match strategy { MatmulStrategy::Simple { grid_x, grid_y } => { let out = init_matmul_output(&lhs, &rhs); - matmul_mem_coalescing(lhs, rhs, out, grid_x, grid_y) + matmul_simple(lhs, rhs, out, grid_x, grid_y) } - MatmulStrategy::Tiling2d => { + MatmulStrategy::Tiling2d(config) => { let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d_unpadded(lhs, rhs, out) + matmul_tiling_2d(lhs, rhs, out, config) } - MatmulStrategy::Tiling2dPadded => { + MatmulStrategy::Tiling2dPadded(config) => { let out = init_matmul_output(&lhs, &rhs); - matmul_tiling_2d_vec4(lhs, rhs, out) + matmul_tiling_2d_padded(lhs, rhs, out, config) } #[cfg(feature = "autotune")] MatmulStrategy::Autotune => matmul_autotune(lhs, rhs), } } + +pub(crate) fn simple_launch_options( + lhs_shape: &Shape, + rhs_shape: &Shape, + output_shape: &Shape, + workgroup_size_x: usize, + workgroup_size_y: usize, +) -> WorkGroup { + let num_rows = lhs_shape.dims[D - 2]; + let num_cols = rhs_shape.dims[D - 1]; + + // set number of workgroups + let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; + let mut num_iter = 1; + for i in 0..D - 2 { + num_iter *= output_shape.dims[i]; + } + + WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32) +} + +pub(crate) fn tiling2d_launch_options( + output_shape: &Shape, + config: Tiling2dConfig, +) -> WorkGroup { + let num_rows = output_shape.dims[D - 2]; + let num_cols = output_shape.dims[D - 1]; + + // set number of workgroups + let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32; + let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32; + let mut num_iter = 1; + for i in 0..D - 2 { + num_iter *= output_shape.dims[i]; + } + + WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32) +} diff --git a/crates/burn-jit/src/kernel/matmul/mod.rs b/crates/burn-jit/src/kernel/matmul/mod.rs index 920db368b0..324827ea41 100644 --- a/crates/burn-jit/src/kernel/matmul/mod.rs +++ b/crates/burn-jit/src/kernel/matmul/mod.rs @@ -1,13 +1,22 @@ mod base; -mod mem_coalescing; +mod simple; mod tiling2d; +mod tiling2d_shader; mod tune; /// Contains utilitary for matmul operation pub mod utils; pub use base::*; -pub use mem_coalescing::*; -pub use tiling2d::*; +pub use simple::*; pub use tune::*; pub use utils::*; + +#[cfg(feature = "export_tests")] +#[allow(missing_docs)] +pub mod padding; + +#[cfg(not(feature = "export_tests"))] +mod padding; + +pub use tiling2d::*; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/padding.rs b/crates/burn-jit/src/kernel/matmul/padding.rs similarity index 100% rename from crates/burn-jit/src/kernel/matmul/tiling2d/padding.rs rename to crates/burn-jit/src/kernel/matmul/padding.rs diff --git a/crates/burn-jit/src/kernel/matmul/mem_coalescing.rs b/crates/burn-jit/src/kernel/matmul/simple.rs similarity index 89% rename from crates/burn-jit/src/kernel/matmul/mem_coalescing.rs rename to crates/burn-jit/src/kernel/matmul/simple.rs index 8fb305338d..6bc4d2489e 100644 --- a/crates/burn-jit/src/kernel/matmul/mem_coalescing.rs +++ b/crates/burn-jit/src/kernel/matmul/simple.rs @@ -6,15 +6,15 @@ use crate::{ dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, InputInfo, OutputInfo, WorkgroupLaunch, }, - compute::WorkGroup, element::JitElement, kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT}, tensor::JitTensor, Runtime, }; -use burn_tensor::Shape; use std::marker::PhantomData; +use super::simple_launch_options; + #[derive(new, Debug)] struct MatmulEagerKernel { workgroup_size_x: usize, @@ -213,11 +213,11 @@ pub fn matmul_mem_coalescing_default( rhs: JitTensor, out: JitTensor, ) -> JitTensor { - matmul_mem_coalescing::(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT) + matmul_simple::(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT) } /// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes -pub fn matmul_mem_coalescing( +pub fn matmul_simple( lhs: JitTensor, rhs: JitTensor, out: JitTensor, @@ -228,7 +228,7 @@ pub fn matmul_mem_coalescing( let lhs = into_contiguous(lhs); let rhs = into_contiguous(rhs); - let workgroup = launch_options( + let workgroup = simple_launch_options( &lhs.shape, &rhs.shape, &out.shape, @@ -242,9 +242,8 @@ pub fn matmul_mem_coalescing( &[ EagerHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), - EagerHandle::new(&out.handle, &out.strides, &out.shape.dims), ], - &[], + &[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)], None, kernel, WorkgroupLaunch::Custom(workgroup), @@ -253,24 +252,3 @@ pub fn matmul_mem_coalescing( out } - -fn launch_options( - lhs_shape: &Shape, - rhs_shape: &Shape, - output_shape: &Shape, - workgroup_size_x: usize, - workgroup_size_y: usize, -) -> WorkGroup { - let num_rows = lhs_shape.dims[D - 2]; - let num_cols = rhs_shape.dims[D - 1]; - - // set number of workgroups - let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; - let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; - let mut num_iter = 1; - for i in 0..D - 2 { - num_iter *= output_shape.dims[i]; - } - - WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32) -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs new file mode 100644 index 0000000000..d559097431 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -0,0 +1,190 @@ +use burn_tensor::{Element, Shape}; + +use crate::{ + codegen::{ + dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, + Execution, InputInfo, OutputInfo, WorkgroupLaunch, + }, + element::JitElement, + kernel::{into_contiguous, DynamicKernelSource, SourceTemplate}, + tensor::JitTensor, + Runtime, +}; +use std::marker::PhantomData; + +use super::{ + padding::{crop, pad_round, PaddingOutput}, + shape_out, tiling2d_launch_options, + tiling2d_shader::MatmulTiling2dShader, + Tiling2dConfig, +}; + +#[derive(new, Debug)] +struct MatmulTiling2d { + _elem: PhantomData, +} + +#[derive(new, Debug)] +struct MatmulTiling2dEagerKernel { + config: Tiling2dConfig, + bounds_check_required: bool, + _runtime: PhantomData, +} + +impl DynamicKernelSource for MatmulTiling2dEagerKernel { + fn source(&self) -> SourceTemplate { + let mut scope = gpu::Scope::root(); + let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into()); + let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into()); + let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into()); + + scope.write_global_custom(out); + + MatmulTiling2dShader { + variables: gpu::BinaryOperator { lhs, rhs, out }, + config: self.config.clone(), + bounds_check_required: self.bounds_check_required, + unroll: true, + } + .expand(&mut scope); + + let lhs = InputInfo::Array { + item: gpu::Elem::Float.into(), + visibility: gpu::Visibility::Read, + }; + let rhs = InputInfo::Array { + item: gpu::Elem::Float.into(), + visibility: gpu::Visibility::Read, + }; + let out = OutputInfo::Array { + item: gpu::Elem::Float.into(), + }; + + let info = CompilationInfo { + inputs: vec![lhs, rhs], + outputs: vec![out], + scope, + }; + + let settings = CompilationSettings::default().workgroup_size(gpu::WorkgroupSize::new( + self.config.grid_x as u32, + self.config.grid_y as u32, + 1, + )); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) + } + + fn id(&self) -> String { + format!( + "{:?}config={:?}boundcheck={:?}", + core::any::TypeId::of::(), + self.config, + self.bounds_check_required + ) + } +} + +/// Matrix multiplication using tiling 2d algorithm with +/// vec4 primitive on both lhs and rhs, with no padding needed +pub fn matmul_tiling_2d( + lhs: JitTensor, + rhs: JitTensor, + out: JitTensor, + config: Tiling2dConfig, +) -> JitTensor { + let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config); + + let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), bounds_check_required); + let client = lhs.client.clone(); + + let lhs = match lhs.batch_swapped_with_row_col() { + true => into_contiguous(lhs), + false => lhs, + }; + let rhs = match rhs.batch_swapped_with_row_col() { + true => into_contiguous(rhs), + false => rhs, + }; + + Execution::start(kernel, client) + .inputs(&[ + EagerHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + ]) + .outputs(&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)]) + .execute(WorkgroupLaunch::Custom(tiling2d_launch_options( + &out.shape, config, + ))); + + out +} + +/// Matrix multiplication using tiling 2d algorithm with padding needed +pub fn matmul_tiling_2d_padded( + lhs: JitTensor, + rhs: JitTensor, + out: JitTensor, + config: Tiling2dConfig, +) -> JitTensor { + let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), false); + let client = lhs.client.clone(); + + // A tensor may need to be padded, in which case it will implicitly become contiguous + // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. + // If batches were swapped among themselves, or if the last two dims are transposed, the underlying + // kernel handles it without needing to turn it into contiguous. + let round_lhs = pad_round::(lhs, config.block_size_m, config.block_size_k); + let lhs = match round_lhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_lhs.into_tensor(), + }; + let round_rhs = pad_round::(rhs, config.block_size_k, config.block_size_n); + let rhs = match round_rhs { + PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + into_contiguous(tensor) + } + _ => round_rhs.into_tensor(), + }; + + let rounded_output_shape = shape_out(&lhs, &rhs); + + let num_elems = rounded_output_shape.num_elements(); + let buffer = client.empty(num_elems * core::mem::size_of::()); + let rounded_output = JitTensor::new( + rhs.client.clone(), + rhs.device.clone(), + rounded_output_shape.clone(), + buffer, + ); + + Execution::start(kernel, client) + .inputs(&[ + EagerHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + ]) + .outputs(&[EagerHandle::new( + &rounded_output.handle, + &rounded_output.strides, + &rounded_output.shape.dims, + )]) + .execute(WorkgroupLaunch::Custom(tiling2d_launch_options( + &rounded_output.shape, + config, + ))); + + crop(rounded_output, out) +} + +fn check_bound_requirement( + lhs_shape: &Shape, + rhs_shape: &Shape, + config: &Tiling2dConfig, +) -> bool { + lhs_shape.dims[D - 2] % config.block_size_m != 0 + || lhs_shape.dims[D - 1] % config.block_size_k != 0 + || rhs_shape.dims[D - 1] % config.block_size_n != 0 +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d/base.rs deleted file mode 100644 index 0167464fb6..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d/base.rs +++ /dev/null @@ -1,91 +0,0 @@ -use super::padding::{crop, pad_round, PaddingOutput}; -use crate::{ - compute::{DynamicKernel, WorkGroup}, - element::JitElement, - kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource}, - ops::numeric::empty_device, - tensor::JitTensor, - Runtime, -}; -use burn_compute::server::Handle; -use burn_tensor::Shape; - -pub(crate) const B_M: usize = 64; -pub(crate) const B_N: usize = 64; -pub(crate) const B_K: usize = 32; -pub(crate) const WORKGROUP_SIZE: usize = 16; - -pub(super) fn make_workgroup(output_shape: &Shape) -> WorkGroup { - let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32; - let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32; - let mut num_blocks_z = 1; - for i in 0..D - 2 { - num_blocks_z *= output_shape.dims[i]; - } - - WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) -} - -pub(super) fn make_info_handle( - lhs: &JitTensor, - rhs: &JitTensor, - output: &JitTensor, -) -> Handle { - let info = build_info(&[lhs, rhs, output]); - rhs.client.create(bytemuck::cast_slice(&info)) -} - -#[allow(clippy::too_many_arguments)] -pub(super) fn matmul_tiling_2d_launch< - R: Runtime, - E: JitElement, - const D: usize, - K: DynamicKernelSource + 'static, ->( - lhs: JitTensor, - rhs: JitTensor, - output: JitTensor, - kernel: K, -) -> JitTensor { - // A tensor may need to be padded, in which case it will implicitly become contiguous - // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. - // If batches were swapped among themselves, or if the last two dims are transposed, the underlying - // kernel handles it without needing to turn it into contiguous. - let round_lhs = pad_round::(lhs, B_M, B_K); - let lhs = match round_lhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_lhs.into_tensor(), - }; - let round_rhs = pad_round::(rhs, B_K, B_N); - let rhs = match round_rhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { - into_contiguous(tensor) - } - _ => round_rhs.into_tensor(), - }; - - let rounded_output_shape = shape_out(&lhs, &rhs); - - let rounded_output = empty_device( - rhs.client.clone(), - rhs.device.clone(), - rounded_output_shape.clone(), - ); - - let workgroup = make_workgroup(&rounded_output_shape); - let info_handle = make_info_handle(&lhs, &rhs, &rounded_output); - - lhs.client.execute( - Box::new(DynamicKernel::new(kernel, workgroup)), - &[ - &lhs.handle, - &rhs.handle, - &rounded_output.handle, - &info_handle, - ], - ); - - crop(rounded_output, output) -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d/mod.rs deleted file mode 100644 index 02f60c8148..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -mod base; - -#[cfg(feature = "export_tests")] -#[allow(missing_docs)] -pub mod padding; - -#[cfg(not(feature = "export_tests"))] -mod padding; - -/// WGSL vec4 primitives are used on left and right hand tensor, -/// padding is avoided through the use of conditions in the kernel -pub mod unpadded; -/// WGSL vec4 primitives are used on left and right hand tensor -pub mod vec4; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/unpadded.rs b/crates/burn-jit/src/kernel/matmul/tiling2d/unpadded.rs deleted file mode 100644 index cf88d70cac..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d/unpadded.rs +++ /dev/null @@ -1,74 +0,0 @@ -use burn_tensor::Element; - -use crate::{ - compute::DynamicKernel, - element::JitElement, - kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource}, - tensor::JitTensor, - Runtime, -}; -use std::marker::PhantomData; - -use crate::kernel_wgsl; - -use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZE}; - -kernel_wgsl!( - MatmulTiling2DUnpaddedRaw, - "../../../template/matmul/blocktiling_2d/unpadded.wgsl" -); - -#[derive(new, Debug)] -struct MatmulTiling2DUnpadded { - _elem: PhantomData, -} - -impl DynamicKernelSource for MatmulTiling2DUnpadded { - fn source(&self) -> SourceTemplate { - MatmulTiling2DUnpaddedRaw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } -} - -/// Matrix multiplication using tiling 2d algorithm with -/// vec4 primitive on both lhs and rhs, with no padding needed -pub fn matmul_tiling_2d_unpadded( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, -) -> JitTensor { - let lhs = match lhs.batch_swapped_with_row_col() { - true => into_contiguous(lhs), - false => lhs, - }; - let rhs = match rhs.batch_swapped_with_row_col() { - true => into_contiguous(rhs), - false => rhs, - }; - - let workgroup = make_workgroup(&out.shape); - let info_handle = make_info_handle(&lhs, &rhs, &out); - - lhs.client.execute( - Box::new(DynamicKernel::new( - MatmulTiling2DUnpadded::::new(), - workgroup, - )), - &[&lhs.handle, &rhs.handle, &out.handle, &info_handle], - ); - - out -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d/vec4.rs b/crates/burn-jit/src/kernel/matmul/tiling2d/vec4.rs deleted file mode 100644 index 24d1b91ed9..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d/vec4.rs +++ /dev/null @@ -1,49 +0,0 @@ -use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE}; -use crate::{ - element::JitElement, - kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, - tensor::JitTensor, -}; -use crate::{kernel_wgsl, Runtime}; -use std::marker::PhantomData; - -kernel_wgsl!( - MatmulTiling2Dvec4Raw, - "../../../template/matmul/blocktiling_2d/vec4.wgsl" -); - -#[derive(new, Debug)] -struct MatmulTiling2Dvec4 { - _elem: PhantomData, -} - -impl DynamicKernelSource for MatmulTiling2Dvec4 { - fn source(&self) -> SourceTemplate { - MatmulTiling2Dvec4Raw::source() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk_4", (B_M * B_K / 4).to_string()) - .register("bk_x_bn_4", (B_K * B_N / 4).to_string()) - .register("workgroup_size_x", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_y", WORKGROUP_SIZE.to_string()) - .register("workgroup_size_z", "1".to_string()) - .register("elem", E::type_name()) - .register("int", "i32") - } - - fn id(&self) -> String { - std::format!("{:?}", self) - } -} - -/// Matrix multiplication using tiling 2d algorithm with -/// vec4 primitive on both lhs and rhs -pub fn matmul_tiling_2d_vec4( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, -) -> JitTensor { - let kernel = MatmulTiling2Dvec4::::new(); - matmul_tiling_2d_launch::(lhs, rhs, out, kernel) -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs new file mode 100644 index 0000000000..9af7ed2391 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs @@ -0,0 +1,68 @@ +use crate::gpu::{gpu, BinaryOperator, Scope, Synchronization, Variable}; + +use crate::kernel::matmul::tiling2d_shader::{ + computation_loop, gather_shader_information, load_shared_memory, write_to_output, +}; +use crate::kernel::matmul::Tiling2dConfig; + +pub(crate) struct MatmulTiling2dShader { + pub variables: BinaryOperator, + pub config: Tiling2dConfig, + pub bounds_check_required: bool, + pub unroll: bool, +} + +pub(crate) struct Tiling2dState { + pub n_loops: Variable, + pub k: Variable, + pub lhs: Variable, + pub rhs: Variable, + pub out: Variable, + pub offset_lhs: Variable, + pub offset_rhs: Variable, + pub offset_output: Variable, + pub row: Variable, + pub col: Variable, + pub dim_m: Variable, + pub dim_k: Variable, + pub dim_n: Variable, + pub thread_col: Variable, + pub thread_row: Variable, + pub shared_lhs: Variable, + pub shared_rhs: Variable, + pub register_m: Variable, + pub register_n: Variable, + pub results: Variable, + pub lhs_stride_col: Variable, + pub lhs_stride_row: Variable, + pub rhs_stride_col: Variable, + pub rhs_stride_row: Variable, + pub out_stride_row: Variable, + pub out_stride_col: Variable, +} + +impl MatmulTiling2dShader { + pub(crate) fn expand(self, scope: &mut Scope) { + let shader_state = gather_shader_information(scope, &self); + + let block_size_k: Variable = self.config.block_size_k.into(); + gpu!( + scope, + range(0u32, shader_state.n_loops).for_each(|i, scope| { + // From 0 to K with steps block_size_k + let k = shader_state.k; + gpu!(scope, k = i * block_size_k); + + load_shared_memory(scope, &self, &shader_state); + + scope.register(Synchronization::WorkgroupBarrier); + + computation_loop(scope, &self, &shader_state); + + scope.register(Synchronization::WorkgroupBarrier); + }) + ); + + write_to_output(scope, &self, &shader_state); + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs new file mode 100644 index 0000000000..47ca91e561 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs @@ -0,0 +1,82 @@ +use crate::gpu::{gpu, Elem, Scope, Variable}; + +use super::{MatmulTiling2dShader, Tiling2dState}; + +#[allow(clippy::too_many_arguments)] +pub fn computation_loop( + scope: &mut Scope, + shader: &MatmulTiling2dShader, + shader_state: &Tiling2dState, +) { + let thread_col = shader_state.thread_col; + let thread_row = shader_state.thread_row; + let shared_lhs = shader_state.shared_lhs; + let shared_rhs = shader_state.shared_rhs; + let register_m = shader_state.register_m; + let register_n = shader_state.register_n; + let results = shader_state.results; + + let block_size_k: Variable = shader.config.block_size_k.into(); + let block_size_n: Variable = shader.config.block_size_n.into(); + let elem = results.item().elem(); + + let lhs_sm_position = scope.create_local(Elem::UInt); + let rhs_sm_position = scope.create_local(Elem::UInt); + + let registered_m = scope.create_local(elem); + let registered_n = scope.create_local(elem); + + let multiplied = scope.create_local(elem); + let results_position = scope.create_local(Elem::UInt); + let results_before = scope.create_local(elem); + let results_after = scope.create_local(elem); + + gpu!( + scope, + range(0u32, shader.config.block_size_k as u32, shader.unroll).for_each( + |dot_index, scope| { + // Load a subcolumn of values from lhs + gpu!(scope, lhs_sm_position = thread_row / 4u32); + gpu!(scope, lhs_sm_position *= block_size_k); + gpu!(scope, lhs_sm_position += dot_index); + gpu!(scope, register_m = shared_lhs[lhs_sm_position]); + + // Load a subrow of values from rhs + gpu!(scope, rhs_sm_position = dot_index * block_size_n); + gpu!(scope, rhs_sm_position += thread_col); + gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32); + gpu!(scope, register_n = shared_rhs[rhs_sm_position]); + + gpu!( + scope, + range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each( + |res_idx_m, scope| { + gpu!( + scope, + range(0u32, shader.config.tile_size_n as u32, shader.unroll) + .for_each(|res_idx_n, scope| { + gpu!(scope, registered_m = register_m[res_idx_m]); + gpu!(scope, registered_n = register_n[res_idx_n]); + + gpu!(scope, multiplied = registered_m * registered_n); + + gpu!( + scope, + results_position = + res_idx_m * shader.config.tile_size_n + ); + gpu!(scope, results_position += res_idx_n); + + gpu!(scope, results_before = results[results_position]); + gpu!(scope, results_after = results_before + multiplied); + + gpu!(scope, results[results_position] = results_after); + }) + ); + } + ) + ); + } + ) + ); +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs new file mode 100644 index 0000000000..fd762d8f36 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs @@ -0,0 +1,278 @@ +use crate::gpu::{gpu, Elem, Scope, Variable}; + +use super::{MatmulTiling2dShader, Tiling2dState}; + +enum InputIdentifier { + Lhs, + Rhs, +} + +pub(crate) fn load_shared_memory( + scope: &mut Scope, + shader: &MatmulTiling2dShader, + shader_state: &Tiling2dState, +) { + if shader.bounds_check_required { + load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Lhs); + load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Rhs); + } else { + load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Lhs); + load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Rhs); + } +} + +#[allow(clippy::too_many_arguments)] +fn load_shared_memory_with_bound_check( + scope: &mut Scope, + shader: &MatmulTiling2dShader, + shader_state: &Tiling2dState, + input_identifier: InputIdentifier, +) { + let ( + input, + input_offset, + shared_memory, + thread_idx_1, + thread_idx_2, + stride_1, + stride_2, + dim, + pos_in_dim, + ) = match input_identifier { + InputIdentifier::Lhs => ( + shader_state.lhs, + shader_state.offset_lhs, + shader_state.shared_lhs, + shader_state.thread_col, + shader_state.thread_row, + shader_state.lhs_stride_col, + shader_state.lhs_stride_row, + shader_state.dim_m, + shader_state.row, + ), + InputIdentifier::Rhs => ( + shader_state.rhs, + shader_state.offset_rhs, + shader_state.shared_rhs, + shader_state.thread_row, + shader_state.thread_col, + shader_state.rhs_stride_row, + shader_state.rhs_stride_col, + shader_state.dim_n, + shader_state.col, + ), + }; + let k = shader_state.k; + let dim_k = shader_state.dim_k; + + // How close is the thread to the end of the matrix. + // If < 4 then it is an edge case + let remain = scope.create_local(Elem::UInt); + gpu!(scope, remain = dim - pos_in_dim); + + let block_size_k: Variable = shader.config.block_size_k.into(); + let block_size_n: Variable = shader.config.block_size_n.into(); + let elem = input.item().elem(); + + let current = scope.create_local(Elem::UInt); + let aligned_with_shared_memory = scope.create_local(Elem::Bool); + let sm_position = scope.create_local(Elem::UInt); + let within_input = scope.create_local(Elem::Bool); + let current_with_k = scope.create_local(Elem::UInt); + let remain_at_least_1 = scope.create_local(Elem::Bool); + let read_condition = scope.create_local(Elem::Bool); + let val_vec4 = scope.create_local(shared_memory.item()); + + let tmp = scope.create_local(Elem::UInt); + let position_0 = scope.create_local(Elem::UInt); + let position_1 = scope.create_local(Elem::UInt); + let position_2 = scope.create_local(Elem::UInt); + let position_3 = scope.create_local(Elem::UInt); + let remain_n = scope.create_local(Elem::Bool); + + let val_0 = scope.create_local(elem); + let val_1 = scope.create_local(elem); + let val_2 = scope.create_local(elem); + let val_3 = scope.create_local(elem); + let zero: Variable = 0u32.into(); + + gpu!( + scope, + range(0_u32, 4u32, shader.unroll).for_each(|j, scope| { + gpu!(scope, current = thread_idx_1 + j); + + gpu!(scope, aligned_with_shared_memory = current < block_size_k); + + // To avoid overwriting following row in shared memory + gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ + + // Position in shared memory + match input_identifier { + InputIdentifier::Lhs => { + gpu!(scope, sm_position = thread_idx_2 / 4u32); + gpu!(scope, sm_position *= block_size_k); + gpu!(scope, sm_position += current); + }, + InputIdentifier::Rhs => { + gpu!(scope, sm_position = current * block_size_n); + gpu!(scope, sm_position += thread_idx_2); + gpu!(scope, sm_position = sm_position / 4u32); + } + } + + // To pad with zeros if outside lhs + gpu!(scope, current_with_k = current + k); + gpu!(scope, within_input = current_with_k < dim_k); + gpu!(scope, remain_at_least_1 = remain >= 1u32); + gpu!(scope, read_condition = within_input && remain_at_least_1); + + gpu!(scope, if(read_condition).then(|scope| { + gpu!(scope, position_0 = k + current); + gpu!(scope, position_0 *= stride_1); + gpu!(scope, tmp = thread_idx_2 * stride_2); + gpu!(scope, position_0 += tmp); + gpu!(scope, position_0 += input_offset); + gpu!(scope, position_1 = position_0 + stride_2); + gpu!(scope, position_2 = position_1 + stride_2); + gpu!(scope, position_3 = position_2 + stride_2); + + gpu!(scope, remain_n = remain >= 4u32); + gpu!(scope, if(remain_n).then(|scope|{ + gpu!(scope, val_0 = input[position_0]); + gpu!(scope, val_1 = input[position_1]); + gpu!(scope, val_2 = input[position_2]); + gpu!(scope, val_3 = input[position_3]); + + }).else(|scope|{ + gpu!(scope, remain_n = remain == 3u32); + gpu!(scope, if(remain_n).then(|scope|{ + gpu!(scope, val_0 = input[position_0]); + gpu!(scope, val_1 = input[position_1]); + gpu!(scope, val_2 = input[position_2]); + gpu!(scope, val_3 = zero); + + }).else(|scope|{ + gpu!(scope, remain_n = remain == 2u32); + gpu!(scope, if(remain_n).then(|scope|{ + gpu!(scope, val_0 = input[position_0]); + gpu!(scope, val_1 = input[position_1]); + gpu!(scope, val_2 = zero); + gpu!(scope, val_3 = zero); + + }).else(|scope|{ + gpu!(scope, remain_n = remain == 1u32); + gpu!(scope, if(remain_n).then(|scope|{ + gpu!(scope, val_0 = input[position_0]); + gpu!(scope, val_1 = zero); + gpu!(scope, val_2 = zero); + gpu!(scope, val_3 = zero); + })); + })); + })); + })); + + gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); + gpu!(scope, shared_memory[sm_position] = val_vec4); + + }).else(|scope|{ + gpu!(scope, val_0 = zero); + gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0)); + gpu!(scope, shared_memory[sm_position] = val_vec4); + })); + })); + }) + ); +} + +#[allow(clippy::too_many_arguments)] +fn load_shared_memory_no_bound_check( + scope: &mut Scope, + shader: &MatmulTiling2dShader, + shader_state: &Tiling2dState, + input_identifier: InputIdentifier, +) { + let (input, input_offset, shared_memory, thread_idx_1, thread_idx_2, stride_1, stride_2) = + match input_identifier { + InputIdentifier::Lhs => ( + shader_state.lhs, + shader_state.offset_lhs, + shader_state.shared_lhs, + shader_state.thread_col, + shader_state.thread_row, + shader_state.lhs_stride_col, + shader_state.lhs_stride_row, + ), + InputIdentifier::Rhs => ( + shader_state.rhs, + shader_state.offset_rhs, + shader_state.shared_rhs, + shader_state.thread_row, + shader_state.thread_col, + shader_state.rhs_stride_row, + shader_state.rhs_stride_col, + ), + }; + let k = shader_state.k; + + let block_size_k: Variable = shader.config.block_size_k.into(); + let block_size_n: Variable = shader.config.block_size_n.into(); + let elem = input.item().elem(); + + let current = scope.create_local(Elem::UInt); + let aligned_with_shared_memory = scope.create_local(Elem::Bool); + let sm_position = scope.create_local(Elem::UInt); + + let tmp = scope.create_local(Elem::UInt); + let position_0 = scope.create_local(Elem::UInt); + let position_1 = scope.create_local(Elem::UInt); + let position_2 = scope.create_local(Elem::UInt); + let position_3 = scope.create_local(Elem::UInt); + let val_0 = scope.create_local(elem); + let val_1 = scope.create_local(elem); + let val_2 = scope.create_local(elem); + let val_3 = scope.create_local(elem); + let val_vec4 = scope.create_local(shared_memory.item()); + + gpu!( + scope, + range(0_u32, 4u32, shader.unroll).for_each(|j, scope| { + gpu!(scope, current = thread_idx_1 + j); + + gpu!(scope, aligned_with_shared_memory = current < block_size_k); + + // To avoid overwriting following row in shared memory + gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ + + match input_identifier { + InputIdentifier::Lhs => { + gpu!(scope, sm_position = thread_idx_2 / 4u32); + gpu!(scope, sm_position *= block_size_k); + gpu!(scope, sm_position += current); + }, + InputIdentifier::Rhs => { + gpu!(scope, sm_position = current * block_size_n); + gpu!(scope, sm_position += thread_idx_2); + gpu!(scope, sm_position = sm_position / 4u32); + } + } + + gpu!(scope, position_0 = k + current); + gpu!(scope, position_0 *= stride_1); + gpu!(scope, tmp = thread_idx_2 * stride_2); + gpu!(scope, position_0 += tmp); + gpu!(scope, position_0 += input_offset); + gpu!(scope, position_1 = position_0 + stride_2); + gpu!(scope, position_2 = position_1 + stride_2); + gpu!(scope, position_3 = position_2 + stride_2); + + gpu!(scope, val_0 = input[position_0]); + gpu!(scope, val_1 = input[position_1]); + gpu!(scope, val_2 = input[position_2]); + gpu!(scope, val_3 = input[position_3]); + + gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); + gpu!(scope, shared_memory[sm_position] = val_vec4); + })); + }) + ); +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs new file mode 100644 index 0000000000..3ed28903d7 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs @@ -0,0 +1,11 @@ +mod base; +mod computation; +mod load_shared_memory; +mod shader_information; +mod write_output; + +pub(crate) use base::*; +pub(crate) use computation::*; +pub(crate) use load_shared_memory::*; +pub(crate) use shader_information::*; +pub(crate) use write_output::*; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs new file mode 100644 index 0000000000..fca13cebed --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs @@ -0,0 +1,180 @@ +use crate::gpu::{gpu, Elem, Item, Scope, Variable}; + +use super::{MatmulTiling2dShader, Tiling2dState}; + +pub(crate) fn gather_shader_information( + scope: &mut Scope, + shader: &MatmulTiling2dShader, +) -> Tiling2dState { + // Inputs + let lhs = shader.variables.lhs; + let rhs = shader.variables.rhs; + let out = shader.variables.out; + + // Config variables + let block_size_m: Variable = shader.config.block_size_m.into(); + let block_size_k: Variable = shader.config.block_size_k.into(); + let block_size_n: Variable = shader.config.block_size_n.into(); + let tile_size_m: Variable = shader.config.tile_size_m.into(); + let tile_size_n: Variable = shader.config.tile_size_n.into(); + let n_threads_per_row: Variable = + (((shader.config.block_size_n - 1) / shader.config.tile_size_n) + 1).into(); + let results_size = (shader.config.tile_size_m * shader.config.tile_size_n) as u32; + + // Shader info + let local_idx = Variable::LocalInvocationIndex; + let batch = Variable::GlobalInvocationIdZ; + + // Shapes + let rank = Variable::Rank; + let last_dim = scope.create_local(Elem::UInt); + let second_to_last_dim = scope.create_local(Elem::UInt); + let dim_m = scope.create_local(Elem::UInt); + let dim_k = scope.create_local(Elem::UInt); + let dim_n = scope.create_local(Elem::UInt); + gpu!(scope, last_dim = rank - 1u32); + gpu!(scope, second_to_last_dim = rank - 2u32); + gpu!(scope, dim_m = shape(lhs, second_to_last_dim)); + gpu!(scope, dim_k = shape(lhs, last_dim)); + gpu!(scope, dim_n = shape(rhs, last_dim)); + + // Strides + let lhs_stride_row = scope.create_local(Elem::UInt); + let lhs_stride_col = scope.create_local(Elem::UInt); + let rhs_stride_row = scope.create_local(Elem::UInt); + let rhs_stride_col = scope.create_local(Elem::UInt); + let out_stride_row = scope.create_local(Elem::UInt); + let out_stride_col = scope.create_local(Elem::UInt); + gpu!(scope, lhs_stride_row = stride(lhs, second_to_last_dim)); + gpu!(scope, lhs_stride_col = stride(lhs, last_dim)); + gpu!(scope, rhs_stride_row = stride(rhs, second_to_last_dim)); + gpu!(scope, rhs_stride_col = stride(rhs, last_dim)); + gpu!(scope, out_stride_row = stride(out, second_to_last_dim)); + gpu!(scope, out_stride_col = stride(out, last_dim)); + + // Workgroup offset + let skip_row = scope.create_local(Elem::UInt); + let skip_col = scope.create_local(Elem::UInt); + let workgroup_id_x = Variable::WorkgroupIdX; + let workgroup_id_y = Variable::WorkgroupIdY; + gpu!(scope, skip_row = workgroup_id_x); + gpu!(scope, skip_row *= block_size_m); + gpu!(scope, skip_col = workgroup_id_y); + gpu!(scope, skip_col *= block_size_n); + + // Position of the first element of the thread, relative to the block + let thread_row = scope.create_local(Elem::UInt); + let thread_col = scope.create_local(Elem::UInt); + gpu!(scope, thread_row = local_idx / n_threads_per_row); + gpu!(scope, thread_row *= tile_size_m); + gpu!(scope, thread_col = local_idx % n_threads_per_row); + gpu!(scope, thread_col *= tile_size_n); + + // Position of the first element of the thread, in absolute (in one batch) + let row = scope.create_local(Elem::UInt); + let col = scope.create_local(Elem::UInt); + gpu!(scope, row = skip_row + thread_row); + gpu!(scope, col = skip_col + thread_col); + + // Calculate offset. + let offset_lhs = scope.create_local(Elem::UInt); + let offset_rhs = scope.create_local(Elem::UInt); + gpu!(scope, offset_lhs = skip_row * lhs_stride_row); + gpu!(scope, offset_rhs = skip_col * rhs_stride_col); + + // Batch offset for the output. + let offset_output = scope.create_local(Elem::UInt); + let batch_dims = scope.create_local(Elem::UInt); + gpu!(scope, offset_output = dim_m * dim_n); + gpu!(scope, offset_output = offset_output * batch); + + // Batch offset for the lhs & rhs matrices. + let stride_lhs = scope.create_local(Elem::UInt); + let stride_rhs = scope.create_local(Elem::UInt); + let stride_output = scope.create_local(Elem::UInt); + let shape_lhs = scope.create_local(Elem::UInt); + let shape_rhs = scope.create_local(Elem::UInt); + let tmp = scope.create_local(Elem::UInt); + let tmp_lhs = scope.create_local(Elem::UInt); + let tmp_rhs = scope.create_local(Elem::UInt); + gpu!(scope, batch_dims = rank - 2u32); + gpu!( + scope, + range(0u32, batch_dims).for_each(|b, scope| { + gpu!(scope, stride_lhs = stride(lhs, b)); + gpu!(scope, stride_rhs = stride(rhs, b)); + gpu!(scope, stride_output = stride(out, b)); + gpu!(scope, shape_lhs = shape(lhs, b)); + gpu!(scope, shape_rhs = shape(rhs, b)); + + gpu!(scope, tmp = offset_output / stride_output); + gpu!(scope, tmp_lhs = tmp % shape_lhs); + gpu!(scope, tmp_lhs = tmp_lhs * stride_lhs); + gpu!(scope, offset_lhs += tmp_lhs); + + gpu!(scope, tmp_rhs = tmp % shape_rhs); + gpu!(scope, tmp_rhs = tmp_rhs * stride_rhs); + gpu!(scope, offset_rhs += tmp_rhs); + }) + ); + + let elem = lhs.item().elem(); + + // Registers used in the compute pass + let results = scope.create_local_array(elem, results_size); + let register_m = scope.create_local(Item::Vec4(elem)); + let register_n = scope.create_local(Item::Vec4(elem)); + let shared_lhs = scope.create_shared( + Item::Vec4(elem), + shader.config.block_size_m as u32 * shader.config.block_size_k as u32 / 4u32, + ); + let shared_rhs = scope.create_shared( + Item::Vec4(elem), + shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32, + ); + + // Calculate exact number of loop iterations + let n_loops = scope.create_local(Elem::UInt); + let k = scope.create_local(Elem::UInt); + if shader.bounds_check_required { + let dim_k_float = scope.create_local(elem); + let block_size_k_float = scope.create_local(elem); + let n_loops_float = scope.create_local(elem); + gpu!(scope, dim_k_float = dim_k); + gpu!(scope, block_size_k_float = block_size_k); + gpu!(scope, n_loops_float = dim_k_float / block_size_k_float); + gpu!(scope, n_loops_float = ceil(n_loops_float)); + gpu!(scope, n_loops = n_loops_float); + } else { + gpu!(scope, n_loops = dim_k / block_size_k); + } + + Tiling2dState { + n_loops, + k, + lhs, + rhs, + out, + offset_lhs, + offset_rhs, + offset_output, + row, + col, + dim_m, + dim_k, + dim_n, + thread_col, + thread_row, + shared_lhs, + shared_rhs, + register_m, + register_n, + results, + lhs_stride_col, + lhs_stride_row, + rhs_stride_col, + rhs_stride_row, + out_stride_row, + out_stride_col, + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs new file mode 100644 index 0000000000..ea09a0c9cf --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs @@ -0,0 +1,121 @@ +use crate::gpu::{gpu, Elem, Scope, Variable}; + +use super::{MatmulTiling2dShader, Tiling2dState}; + +#[allow(clippy::too_many_arguments)] +pub fn write_to_output( + scope: &mut Scope, + shader: &MatmulTiling2dShader, + shader_state: &Tiling2dState, +) { + let row = shader_state.row; + let col = shader_state.col; + + let row_index = scope.create_local(Elem::UInt); + let col_index = scope.create_local(Elem::UInt); + + if shader.bounds_check_required { + let dim_m = shader_state.dim_m; + let dim_n = shader_state.dim_n; + + let within_output = scope.create_local(Elem::Bool); + let within_output_tmp = scope.create_local(Elem::Bool); + + gpu!( + scope, + range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each( + |res_idx_m, scope| { + gpu!( + scope, + range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each( + |res_idx_n, scope| { + gpu!(scope, row_index = row + res_idx_m); + gpu!(scope, col_index = col + res_idx_n); + + gpu!(scope, within_output = row_index < dim_m); + gpu!(scope, within_output_tmp = col_index < dim_n); + gpu!(scope, within_output = within_output && within_output_tmp); + + gpu!(scope, if(within_output).then(|scope|{ + write_inner( + scope, + shader, + shader_state, + res_idx_m, + res_idx_n, + row_index, + col_index, + ); + })); + } + ) + ); + } + ) + ); + } else { + gpu!( + scope, + range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each( + |res_idx_m, scope| { + gpu!( + scope, + range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each( + |res_idx_n, scope| { + gpu!(scope, row_index = row + res_idx_m); + gpu!(scope, col_index = col + res_idx_n); + + write_inner( + scope, + shader, + shader_state, + res_idx_m, + res_idx_n, + row_index, + col_index, + ) + } + ) + ); + } + ) + ); + } +} + +#[allow(clippy::too_many_arguments)] +fn write_inner( + scope: &mut Scope, + shader: &MatmulTiling2dShader, + shader_state: &Tiling2dState, + res_idx_m: Variable, + res_idx_n: Variable, + row_index: Variable, + col_index: Variable, +) { + let offset_output = shader_state.offset_output; + let out = shader_state.out; + let out_stride_row = shader_state.out_stride_row; + let out_stride_col = shader_state.out_stride_col; + let results = shader_state.results; + + let elem = results.item().elem(); + let results_position = scope.create_local(Elem::UInt); + let result = scope.create_local(elem); + let output_position = scope.create_local(Elem::UInt); + + gpu!( + scope, + results_position = res_idx_m * shader.config.tile_size_n + ); + gpu!(scope, results_position += res_idx_n); + + gpu!(scope, result = results[results_position]); + + gpu!(scope, row_index *= out_stride_row); + gpu!(scope, col_index *= out_stride_col); + gpu!(scope, output_position = row_index + col_index); + gpu!(scope, output_position += offset_output); + + gpu!(scope, out[output_position] = result); +} diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 7208118644..1fa9ab0a9e 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -4,7 +4,10 @@ use burn_tensor::{Element, ElementConversion}; use crate::{ compute::JitAutotuneKey, element::JitElement, - kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, + kernel::{ + matmul::{utils::init_matmul_output, Tiling2dConfig}, + prng::random_like_uniform, + }, ops::numeric::empty_device, tensor::JitTensor, Runtime, @@ -50,22 +53,14 @@ impl AutotuneOperationSet AutotuneOperationSet, fastest_index: usize) -> Box { match fastest_index { - 0 => Box::new(MemoryCoalescingMatmulDefault::new( - self.lhs, self.rhs, self.out, - )), - 1 => Box::new(MemoryCoalescingMatmulW16x16::new( - self.lhs, self.rhs, self.out, - )), - 2 => Box::new(Vec4TilingMatmulDefault::new(self.lhs, self.rhs, self.out)), - 3 => Box::new(Vec4TilingMatmulUnpaddedDefault::new( - self.lhs, self.rhs, self.out, - )), + 0 => Box::new(SimpleMatmul::new(self.lhs, self.rhs, self.out)), + 1 => Box::new(SimpleMatmul16x16::new(self.lhs, self.rhs, self.out)), + 2 => Box::new(Tiling2dMatmul::new(self.lhs, self.rhs, self.out)), + 3 => Box::new(Tiling2dMatmulPadded::new(self.lhs, self.rhs, self.out)), _ => panic!("Fastest index is out of bound"), } } @@ -134,23 +123,21 @@ macro_rules! matmul_tune_ops { // Potentially better for small matrices. matmul_tune_ops!( - MemoryCoalescingMatmulDefault, + SimpleMatmul, crate::kernel::matmul::matmul_mem_coalescing_default ); // Potentially better for small matrices. -matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| { - crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16) +matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { + crate::kernel::matmul::matmul_simple(lhs, rhs, out, 16, 16) }); // Probably the fastest when fixed sizes. -matmul_tune_ops!( - Vec4TilingMatmulDefault, - crate::kernel::matmul::vec4::matmul_tiling_2d_vec4 -); +matmul_tune_ops!(Tiling2dMatmulPadded, |lhs, rhs, out| { + crate::kernel::matmul::matmul_tiling_2d_padded(lhs, rhs, out, Tiling2dConfig::default()) +}); -// Probably the fastest otherwise. -matmul_tune_ops!( - Vec4TilingMatmulUnpaddedDefault, - crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded -); +// Probably the fastest in the general case +matmul_tune_ops!(Tiling2dMatmul, |lhs, rhs, out| { + crate::kernel::matmul::matmul_tiling_2d(lhs, rhs, out, Tiling2dConfig::default()) +}); diff --git a/crates/burn-jit/src/tests/matmul.rs b/crates/burn-jit/src/tests/matmul.rs index 0ec2a66bf6..7c85478057 100644 --- a/crates/burn-jit/src/tests/matmul.rs +++ b/crates/burn-jit/src/tests/matmul.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(matmul)] mod tests { use super::*; - use burn_jit::kernel::matmul::{matmul, MatmulStrategy}; + use burn_jit::kernel::matmul::{matmul, MatmulStrategy, Tiling2dConfig}; use burn_tensor::{Shape, Tensor}; mod simple { @@ -174,7 +174,7 @@ mod tests { let shape_lhs = [3, 2, 4, 4]; let shape_rhs = [3, 2, 4, 4]; same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dPadded, + MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), swap, swap, shape_lhs, @@ -189,7 +189,7 @@ mod tests { let shape_lhs = [3, 2, 4, 4]; let shape_rhs = [3, 2, 4, 4]; same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dPadded, + MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), swap_lhs, swap_rhs, shape_lhs, @@ -204,7 +204,7 @@ mod tests { let shape_lhs = [4, 4, 4, 4]; let shape_rhs = [4, 4, 4, 4]; same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dPadded, + MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), swap_lhs, swap_rhs, shape_lhs, @@ -212,10 +212,56 @@ mod tests { ); } + #[test] + fn stable_test() { + let ref_tensor_device = Default::default(); + let x = ReferenceTensor::from_floats([[0., 1., 2.], [3., 4., 5.]], &ref_tensor_device); + let y = + ReferenceTensor::from_floats([[0., 1.], [2., 3.], [4., 5.]], &ref_tensor_device); + + let test_tensor_device = Default::default(); + let x_jit = TestTensor::from_data(x.to_data(), &test_tensor_device); + let y_jit = TestTensor::from_data(y.to_data(), &test_tensor_device); + + let z_reference = x.matmul(y); + let z = Tensor::::from_primitive(matmul( + x_jit.into_primitive(), + y_jit.into_primitive(), + MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), + )); + + z_reference.into_data().assert_approx_eq(&z.into_data(), 3); + } + + #[test] + fn stable_test_2() { + let ref_tensor_device = Default::default(); + let x = + ReferenceTensor::from_floats([[0., 1.], [2., 3.], [4., 5.]], &ref_tensor_device); + let y = ReferenceTensor::from_floats([[0., 1., 2.], [3., 4., 5.]], &ref_tensor_device); + + let test_tensor_device = Default::default(); + let x_jit = TestTensor::from_data(x.to_data(), &test_tensor_device); + let y_jit = TestTensor::from_data(y.to_data(), &test_tensor_device); + + let z_reference = x.matmul(y); + let z = Tensor::::from_primitive(matmul( + x_jit.into_primitive(), + y_jit.into_primitive(), + MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), + )); + + z_reference.into_data().assert_approx_eq(&z.into_data(), 3); + } + fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { let shape_lhs = [batch_1, batch_2, m, k]; let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(MatmulStrategy::Tiling2dPadded, shape_lhs, shape_rhs); + same_as_reference( + MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), + shape_lhs, + shape_rhs, + ); } } @@ -308,7 +354,7 @@ mod tests { let shape_lhs = [3, 2, 4, 4]; let shape_rhs = [3, 2, 4, 4]; same_as_reference_swapped_dims( - MatmulStrategy::Tiling2d, + MatmulStrategy::Tiling2d(Tiling2dConfig::default()), swap, swap, shape_lhs, @@ -323,7 +369,7 @@ mod tests { let shape_lhs = [3, 2, 4, 4]; let shape_rhs = [3, 2, 4, 4]; same_as_reference_swapped_dims( - MatmulStrategy::Tiling2d, + MatmulStrategy::Tiling2d((Tiling2dConfig::default())), swap_lhs, swap_rhs, shape_lhs, @@ -338,7 +384,7 @@ mod tests { let shape_lhs = [4, 4, 4, 4]; let shape_rhs = [4, 4, 4, 4]; same_as_reference_swapped_dims( - MatmulStrategy::Tiling2d, + MatmulStrategy::Tiling2d(Tiling2dConfig::default()), swap_lhs, swap_rhs, shape_lhs, @@ -349,7 +395,11 @@ mod tests { fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { let shape_lhs = [batch_1, batch_2, m, k]; let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(MatmulStrategy::Tiling2d, shape_lhs, shape_rhs); + same_as_reference( + MatmulStrategy::Tiling2d(Tiling2dConfig::default()), + shape_lhs, + shape_rhs, + ); } } diff --git a/crates/burn-wgpu/src/compiler/wgsl/base.rs b/crates/burn-wgpu/src/compiler/wgsl/base.rs index 45ad810d32..a268b4d166 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/base.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/base.rs @@ -18,6 +18,7 @@ pub enum Variable { scope_depth: u8, }, SharedMemory(u16, Item, u32), + LocalArray(u16, Item, u8, u32), Id, LocalInvocationIndex, LocalInvocationIdX, @@ -79,6 +80,7 @@ impl Variable { Variable::GlobalInputArray(_, _) => false, Variable::GlobalOutputArray(_, _) => false, Variable::SharedMemory(_, _, _) => false, + Variable::LocalArray(_, _, _, _) => false, Variable::Local { index: _, item: _, @@ -110,6 +112,7 @@ impl Variable { Self::GlobalInputArray(_, e) => *e, Self::GlobalOutputArray(_, e) => *e, Self::SharedMemory(_, e, _) => *e, + Self::LocalArray(_, e, _, _) => *e, Self::Local { index: _, item, @@ -222,6 +225,9 @@ impl Display for Variable { Variable::SharedMemory(number, _, _) => { f.write_fmt(format_args!("shared_memory_{number}")) } + Variable::LocalArray(number, _, scope_depth, _) => { + f.write_fmt(format_args!("a_{scope_depth}_{number}")) + } Variable::Id => f.write_str("id"), Variable::LocalInvocationIndex => f.write_str("local_idx"), Variable::LocalInvocationIdX => f.write_str("local_invocation_id.x"), diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index 349243505b..02dacf9d4c 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -1,3 +1,4 @@ +use super::LocalArray; use super::{shader::ComputeShader, Item, SharedMemory}; use crate::compiler::wgsl; use crate::{FloatElement, IntElement}; @@ -19,6 +20,7 @@ pub struct WgslCompiler { shape: bool, num_workgroups: bool, shared_memories: Vec, + local_arrays: Vec, _float: PhantomData, _int: PhantomData, } @@ -44,6 +46,7 @@ impl Default for WgslCompiler { shape: false, num_workgroups: false, shared_memories: Vec::default(), + local_arrays: Vec::default(), _float: PhantomData, _int: PhantomData, } @@ -64,6 +67,10 @@ impl burn_jit::Compiler for WgslCompiler { fn elem_size(elem: gpu::Elem) -> usize { Self::compile_elem(elem).size() } + + fn max_shared_memory_size() -> usize { + 8192 + } } impl WgslCompiler { @@ -98,6 +105,7 @@ impl WgslCompiler { .map(|(name, binding)| (name, Self::compile_binding(binding))) .collect(), shared_memories: self.shared_memories.clone(), + local_arrays: self.local_arrays.clone(), workgroup_size: value.workgroup_size, global_invocation_id: self.global_invocation_id || self.id, local_invocation_index: self.local_invocation_index, @@ -159,6 +167,14 @@ impl WgslCompiler { } wgsl::Variable::SharedMemory(index, item, size) } + gpu::Variable::LocalArray(index, item, scope_depth, size) => { + let item = Self::compile_item(item); + if !self.local_arrays.iter().any(|s| s.index == index) { + self.local_arrays + .push(LocalArray::new(index, item, scope_depth, size)); + } + wgsl::Variable::LocalArray(index, item, scope_depth, size) + } gpu::Variable::Id => { self.id = true; wgsl::Variable::Id @@ -416,6 +432,10 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(op.out), }, + gpu::Operator::Ceil(op) => wgsl::Instruction::Ceil { + input: self.compile_variable(op.input), + out: self.compile_variable(op.out), + }, gpu::Operator::Log(op) => wgsl::Instruction::Log { input: self.compile_variable(op.input), out: self.compile_variable(op.out), diff --git a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs index 59798bdaca..e0c6826534 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs @@ -109,6 +109,10 @@ pub enum Instruction { input: Variable, out: Variable, }, + Ceil { + input: Variable, + out: Variable, + }, Erf { input: Variable, out: Variable, @@ -272,6 +276,9 @@ impl Display for Instruction { Instruction::Sqrt { input, out } => { f.write_fmt(format_args!("{out} = sqrt({input});\n")) } + Instruction::Ceil { input, out } => { + f.write_fmt(format_args!("{out} = ceil({input});\n")) + } Instruction::Log1p { input, out } => { f.write_fmt(format_args!("{out} = log({input} + 1.0);\n")) } diff --git a/crates/burn-wgpu/src/compiler/wgsl/shader.rs b/crates/burn-wgpu/src/compiler/wgsl/shader.rs index 6036584b0d..8229ac4510 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/shader.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/shader.rs @@ -5,7 +5,6 @@ use std::fmt::Display; #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum Location { Storage, - #[allow(dead_code)] Workgroup, } @@ -42,12 +41,32 @@ impl SharedMemory { } } +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct LocalArray { + pub index: u16, + item: Item, + name: u8, + size: u32, +} + +impl LocalArray { + pub fn new(index: u16, item: Item, name: u8, size: u32) -> Self { + Self { + index, + item, + name, + size, + } + } +} + #[derive(Debug, Clone)] pub struct ComputeShader { pub inputs: Vec, pub outputs: Vec, pub named: Vec<(String, Binding)>, pub shared_memories: Vec, + pub local_arrays: Vec, pub workgroup_size: WorkgroupSize, pub global_invocation_id: bool, pub local_invocation_index: bool, @@ -72,10 +91,10 @@ impl Display for ComputeShader { )?; } - for shared_memory in self.shared_memories.iter() { + for array in self.shared_memories.iter() { f.write_fmt(format_args!( "var<{}> shared_memory_{}: array<{}, {}>;\n\n", - shared_memory.location, shared_memory.index, shared_memory.item, shared_memory.size + array.location, array.index, array.item, array.size ))?; } @@ -115,12 +134,22 @@ fn main( f.write_str(" @builtin(workgroup_id) workgroup_id: vec3,\n")?; } - f.write_fmt(format_args!( - ") {{ - {} -}}", - self.body - ))?; + // Open body + f.write_fmt(format_args!(") {{"))?; + + // Local arrays + for array in self.local_arrays.iter() { + f.write_fmt(format_args!( + "var a_{}_{}: array<{}, {}>;\n\n", + array.name, array.index, array.item, array.size + ))?; + } + + // Body + f.write_fmt(format_args!("{}", self.body))?; + + // Close body + f.write_fmt(format_args!("}}"))?; for extension in self.extensions.iter() { f.write_fmt(format_args!("{extension}\n\n"))?;