diff --git a/crates/burn-jit/src/kernel/interpolate.rs b/crates/burn-jit/src/kernel/interpolate.rs deleted file mode 100644 index 2156c8ae89..0000000000 --- a/crates/burn-jit/src/kernel/interpolate.rs +++ /dev/null @@ -1,108 +0,0 @@ -use crate::{ - compute::{Kernel, StaticKernel}, - element::JitElement, - kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, - tensor::JitTensor, - Runtime, -}; -use burn_tensor::{ - ops::{InterpolateMode, InterpolateOptions}, - Element, Shape, -}; - -kernel_wgsl!(Nearest, "../template/interpolate/nearest.wgsl"); -kernel_wgsl!( - NearestBackward, - "../template/interpolate/nearest_backward.wgsl" -); -kernel_wgsl!(Bilinear, "../template/interpolate/bilinear.wgsl"); -kernel_wgsl!(Bicubic, "../template/interpolate/bicubic.wgsl"); - -pub(crate) fn interpolate( - input: JitTensor, - output_size: [usize; 2], - options: InterpolateOptions, -) -> JitTensor { - let input = kernel::into_contiguous(input); - let [batch_size, channels, _, _] = input.shape.dims; - let [out_height, out_width] = output_size; - - let shape_out = Shape::new([batch_size, channels, out_height, out_width]); - let output = empty_device(input.client.clone(), input.device.clone(), shape_out); - - let info = build_info(&[&input, &output]); - - let info_handle = input.client.create(bytemuck::cast_slice(&info)); - - let kernel: Box = match options.mode { - InterpolateMode::Nearest => Box::new(StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - ))), - InterpolateMode::Bilinear => Box::new(StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - ))), - InterpolateMode::Bicubic => Box::new(StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - ))), - }; - - input - .client - .execute(kernel, &[&input.handle, &output.handle, &info_handle]); - - output -} - -pub(crate) fn interpolate_backward( - input: JitTensor, - out_grad: JitTensor, - _output_size: [usize; 2], - options: InterpolateOptions, -) -> JitTensor { - let out_grad = kernel::into_contiguous(out_grad); - let output_shape = input.shape.clone(); - let num_elems = input.shape.num_elements(); - let buffer = input.client.empty(num_elems * core::mem::size_of::()); - let output = JitTensor::new( - input.client.clone(), - input.device.clone(), - output_shape, - buffer, - ); - - let info = build_info(&[&input, &out_grad]); - - let info_handle = out_grad.client.create(bytemuck::cast_slice(&info)); - - let kernel: Box = match options.mode { - InterpolateMode::Nearest => Box::new(StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup( - output.shape.num_elements(), - WORKGROUP_DEFAULT, - ))), - InterpolateMode::Bilinear => { - panic!("bilinear interpolation backward is not supported by JIT backend") - } - InterpolateMode::Bicubic => { - panic!("bicubic interpolation backward is not supported by JIT backend") - } - }; - - input - .client - .execute(kernel, &[&out_grad.handle, &output.handle, &info_handle]); - - output -} diff --git a/crates/burn-jit/src/kernel/interpolate/base.rs b/crates/burn-jit/src/kernel/interpolate/base.rs new file mode 100644 index 0000000000..7961e3fbb3 --- /dev/null +++ b/crates/burn-jit/src/kernel/interpolate/base.rs @@ -0,0 +1,66 @@ +use crate::{ + element::JitElement, kernel::into_contiguous, ops::numeric::empty_device, tensor::JitTensor, + Runtime, +}; +use burn_tensor::{ + ops::{InterpolateMode, InterpolateOptions}, + Element, Shape, +}; + +use super::{ + bicubic::interpolate_bicubic_launch, bilinear::interpolate_bilinear_launch, + nearest::interpolate_nearest_launch, nearest_backward::interpolate_nearest_backward_launch, +}; + +/// Interpolate operation +/// +/// Supports nearest, bilinear and bicubic modes +pub fn interpolate( + input: JitTensor, + output_size: [usize; 2], + options: InterpolateOptions, +) -> JitTensor { + let input = into_contiguous(input); + let [batch_size, channels, _, _] = input.shape.dims; + let [out_height, out_width] = output_size; + + let shape_out = Shape::new([batch_size, channels, out_height, out_width]); + let output = empty_device(input.client.clone(), input.device.clone(), shape_out); + + match options.mode { + InterpolateMode::Nearest => interpolate_nearest_launch(input, output), + InterpolateMode::Bilinear => interpolate_bilinear_launch(input, output), + InterpolateMode::Bicubic => interpolate_bicubic_launch(input, output), + } +} + +/// Backward interpolate operation +/// +/// Note: only nearest mode is supported +pub fn interpolate_backward( + input: JitTensor, + out_grad: JitTensor, + _output_size: [usize; 2], + options: InterpolateOptions, +) -> JitTensor { + let out_grad = into_contiguous(out_grad); + let output_shape = input.shape.clone(); + let num_elems = input.shape.num_elements(); + let buffer = input.client.empty(num_elems * core::mem::size_of::()); + let output = JitTensor::new( + input.client.clone(), + input.device.clone(), + output_shape, + buffer, + ); + + match options.mode { + InterpolateMode::Nearest => interpolate_nearest_backward_launch(out_grad, output), + InterpolateMode::Bilinear => { + panic!("bilinear interpolation backward is not supported by JIT backend") + } + InterpolateMode::Bicubic => { + panic!("bicubic interpolation backward is not supported by JIT backend") + } + } +} diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs new file mode 100644 index 0000000000..70adb78452 --- /dev/null +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -0,0 +1,429 @@ +use std::marker::PhantomData; + +use crate::{ + codegen::{ + execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo, + OutputInfo, WorkgroupLaunch, + }, + gpu::{gpu, Elem, Scope, Variable, Visibility}, + kernel::{DynamicKernelSource, SourceTemplate}, + tensor::JitTensor, + Compiler, JitElement, Runtime, +}; + +#[derive(new)] +struct InterpolateBicubicEagerKernel { + _runtime: PhantomData, + _elem: PhantomData, +} + +struct InterpolateBicubicShader { + input: Variable, + output: Variable, +} + +impl InterpolateBicubicShader { + pub(crate) fn expand(self, scope: &mut Scope) { + let input = self.input; + let output = self.output; + let id = Variable::Id; + + let input_stride_0 = scope.create_local(Elem::UInt); + let input_stride_1 = scope.create_local(Elem::UInt); + let input_stride_2 = scope.create_local(Elem::UInt); + let input_stride_3 = scope.create_local(Elem::UInt); + + let input_shape_2 = scope.create_local(Elem::UInt); + let input_shape_3 = scope.create_local(Elem::UInt); + + let output_stride_0 = scope.create_local(Elem::UInt); + let output_stride_1 = scope.create_local(Elem::UInt); + let output_stride_2 = scope.create_local(Elem::UInt); + let output_stride_3 = scope.create_local(Elem::UInt); + + let output_shape_0 = scope.create_local(Elem::UInt); + let output_shape_1 = scope.create_local(Elem::UInt); + let output_shape_2 = scope.create_local(Elem::UInt); + let output_shape_3 = scope.create_local(Elem::UInt); + + gpu!(scope, input_stride_0 = stride(input, 0u32)); + gpu!(scope, input_stride_1 = stride(input, 1u32)); + gpu!(scope, input_stride_2 = stride(input, 2u32)); + gpu!(scope, input_stride_3 = stride(input, 3u32)); + + gpu!(scope, input_shape_2 = shape(input, 2u32)); + gpu!(scope, input_shape_3 = shape(input, 3u32)); + + gpu!(scope, output_stride_0 = stride(output, 0u32)); + gpu!(scope, output_stride_1 = stride(output, 1u32)); + gpu!(scope, output_stride_2 = stride(output, 2u32)); + gpu!(scope, output_stride_3 = stride(output, 3u32)); + + gpu!(scope, output_shape_0 = shape(output, 0u32)); + gpu!(scope, output_shape_1 = shape(output, 1u32)); + gpu!(scope, output_shape_2 = shape(output, 2u32)); + gpu!(scope, output_shape_3 = shape(output, 3u32)); + + let b = scope.create_local(Elem::UInt); + let c = scope.create_local(Elem::UInt); + let h = scope.create_local(Elem::UInt); + let w = scope.create_local(Elem::UInt); + + gpu!(scope, b = id / output_stride_0); + gpu!(scope, b = b % output_shape_0); + + gpu!(scope, c = id / output_stride_1); + gpu!(scope, c = c % output_shape_1); + + gpu!(scope, h = id / output_stride_2); + gpu!(scope, h = h % output_shape_2); + + gpu!(scope, w = id / output_stride_3); + gpu!(scope, w = w % output_shape_3); + + let input_height = scope.create_local(Elem::UInt); + let output_height = scope.create_local(Elem::UInt); + let output_height_float = scope.create_local(Elem::Float); + + let input_width = scope.create_local(Elem::UInt); + let output_width = scope.create_local(Elem::UInt); + let output_width_float = scope.create_local(Elem::Float); + + let frac = scope.create_local(Elem::Float); + let numerator = scope.create_local(Elem::UInt); + let numerator_float = scope.create_local(Elem::Float); + let not_zero = scope.create_local(Elem::Bool); + + let y_in_float = scope.create_local(Elem::Float); + let y_in = scope.create_local(Elem::UInt); + let yw = scope.create_local(Elem::Float); + let y_tmp = scope.create_local(Elem::UInt); + + gpu!(scope, input_height = input_shape_2 - 1u32); + gpu!(scope, output_height = output_shape_2 - 1u32); + gpu!(scope, numerator = h * input_height); + gpu!(scope, numerator_float = cast(numerator)); + gpu!(scope, output_height_float = cast(output_height)); + gpu!(scope, frac = numerator_float / output_height_float); + gpu!(scope, y_in_float = floor(frac)); + gpu!(scope, y_in = cast(y_in_float)); + gpu!(scope, yw = frac - y_in_float); + + let y0 = scope.zero(Elem::UInt); + gpu!(scope, not_zero = y_in != 0u32); + gpu!(scope, if(not_zero).then(|scope|{ + gpu!(scope, y0 = y_in - 1u32); + })); + + let y1 = y_in; + + gpu!(scope, y_tmp = y_in + 1u32); + let y2 = Self::min(scope, y_tmp, input_height); + + gpu!(scope, y_tmp = y_in + 2u32); + let y3 = Self::min(scope, y_tmp, input_height); + + let x_in_float = scope.create_local(Elem::Float); + let x_in = scope.create_local(Elem::UInt); + let xw = scope.create_local(Elem::Float); + let x_tmp = scope.create_local(Elem::UInt); + + gpu!(scope, input_width = input_shape_3 - 1u32); + gpu!(scope, output_width = output_shape_3 - 1u32); + gpu!(scope, numerator = w * input_width); + gpu!(scope, numerator_float = cast(numerator)); + gpu!(scope, output_width_float = cast(output_width)); + gpu!(scope, frac = numerator_float / output_width_float); + gpu!(scope, x_in_float = floor(frac)); + gpu!(scope, x_in = cast(x_in_float)); + gpu!(scope, xw = frac - x_in_float); + + let x0 = scope.zero(Elem::UInt); + gpu!(scope, not_zero = x_in != 0u32); + gpu!(scope, if(not_zero).then(|scope|{ + gpu!(scope, x0 = x_in - 1u32); + })); + + gpu!(scope, x_tmp = x_in - 1u32); + let x1 = x_in; + + gpu!(scope, x_tmp = x_in + 1u32); + let x2 = Self::min(scope, x_tmp, input_width); + + gpu!(scope, x_tmp = x_in + 2u32); + let x3 = Self::min(scope, x_tmp, input_width); + + let index_base = scope.create_local(Elem::UInt); + let index_tmp = scope.create_local(Elem::UInt); + gpu!(scope, index_base = b * input_stride_0); + gpu!(scope, index_tmp = c * input_stride_1); + gpu!(scope, index_base += index_tmp); + + let y0_stride = scope.create_local(Elem::UInt); + let y1_stride = scope.create_local(Elem::UInt); + let y2_stride = scope.create_local(Elem::UInt); + let y3_stride = scope.create_local(Elem::UInt); + let x0_stride = scope.create_local(Elem::UInt); + let x1_stride = scope.create_local(Elem::UInt); + let x2_stride = scope.create_local(Elem::UInt); + let x3_stride = scope.create_local(Elem::UInt); + gpu!(scope, y0_stride = y0 * input_stride_2); + gpu!(scope, y1_stride = y1 * input_stride_2); + gpu!(scope, y2_stride = y2 * input_stride_2); + gpu!(scope, y3_stride = y3 * input_stride_2); + gpu!(scope, x0_stride = x0 * input_stride_3); + gpu!(scope, x1_stride = x1 * input_stride_3); + gpu!(scope, x2_stride = x2 * input_stride_3); + gpu!(scope, x3_stride = x3 * input_stride_3); + + let index_0 = scope.create_local(Elem::UInt); + let index_1 = scope.create_local(Elem::UInt); + let index_2 = scope.create_local(Elem::UInt); + let index_3 = scope.create_local(Elem::UInt); + let inp_0 = scope.create_local(input.item()); + let inp_1 = scope.create_local(input.item()); + let inp_2 = scope.create_local(input.item()); + let inp_3 = scope.create_local(input.item()); + + gpu!(scope, index_0 = index_base); + gpu!(scope, index_0 += y0_stride); + gpu!(scope, index_0 += x0_stride); + gpu!(scope, inp_0 = input[index_0]); + gpu!(scope, index_1 = index_base); + gpu!(scope, index_1 += y0_stride); + gpu!(scope, index_1 += x1_stride); + gpu!(scope, inp_1 = input[index_1]); + gpu!(scope, index_2 = index_base); + gpu!(scope, index_2 += y0_stride); + gpu!(scope, index_2 += x2_stride); + gpu!(scope, inp_2 = input[index_2]); + gpu!(scope, index_3 = index_base); + gpu!(scope, index_3 += y0_stride); + gpu!(scope, index_3 += x3_stride); + gpu!(scope, inp_3 = input[index_3]); + + let coefficients0 = Self::cubic_interp1d(scope, inp_0, inp_1, inp_2, inp_3, xw); + + gpu!(scope, index_0 = index_base); + gpu!(scope, index_0 += y1_stride); + gpu!(scope, index_0 += x0_stride); + gpu!(scope, inp_0 = input[index_0]); + gpu!(scope, index_1 = index_base); + gpu!(scope, index_1 += y1_stride); + gpu!(scope, index_1 += x1_stride); + gpu!(scope, inp_1 = input[index_1]); + gpu!(scope, index_2 = index_base); + gpu!(scope, index_2 += y1_stride); + gpu!(scope, index_2 += x2_stride); + gpu!(scope, inp_2 = input[index_2]); + gpu!(scope, index_3 = index_base); + gpu!(scope, index_3 += y1_stride); + gpu!(scope, index_3 += x3_stride); + gpu!(scope, inp_3 = input[index_3]); + + let coefficients1 = Self::cubic_interp1d(scope, inp_0, inp_1, inp_2, inp_3, xw); + + gpu!(scope, index_0 = index_base); + gpu!(scope, index_0 += y2_stride); + gpu!(scope, index_0 += x0_stride); + gpu!(scope, inp_0 = input[index_0]); + gpu!(scope, index_1 = index_base); + gpu!(scope, index_1 += y2_stride); + gpu!(scope, index_1 += x1_stride); + gpu!(scope, inp_1 = input[index_1]); + gpu!(scope, index_2 = index_base); + gpu!(scope, index_2 += y2_stride); + gpu!(scope, index_2 += x2_stride); + gpu!(scope, inp_2 = input[index_2]); + gpu!(scope, index_3 = index_base); + gpu!(scope, index_3 += y2_stride); + gpu!(scope, index_3 += x3_stride); + gpu!(scope, inp_3 = input[index_3]); + + let coefficients2 = Self::cubic_interp1d(scope, inp_0, inp_1, inp_2, inp_3, xw); + + gpu!(scope, index_0 = index_base); + gpu!(scope, index_0 += y3_stride); + gpu!(scope, index_0 += x0_stride); + gpu!(scope, inp_0 = input[index_0]); + gpu!(scope, index_1 = index_base); + gpu!(scope, index_1 += y3_stride); + gpu!(scope, index_1 += x1_stride); + gpu!(scope, inp_1 = input[index_1]); + gpu!(scope, index_2 = index_base); + gpu!(scope, index_2 += y3_stride); + gpu!(scope, index_2 += x2_stride); + gpu!(scope, inp_2 = input[index_2]); + gpu!(scope, index_3 = index_base); + gpu!(scope, index_3 += y3_stride); + gpu!(scope, index_3 += x3_stride); + gpu!(scope, inp_3 = input[index_3]); + + let coefficients3 = Self::cubic_interp1d(scope, inp_0, inp_1, inp_2, inp_3, xw); + + let val = Self::cubic_interp1d( + scope, + coefficients0, + coefficients1, + coefficients2, + coefficients3, + yw, + ); + + gpu!(scope, output[id] = val); + } + + fn min(scope: &mut Scope, a: Variable, b: Variable) -> Variable { + let cond = scope.create_local(Elem::Bool); + let res = scope.create_local(a.item()); + + gpu!(scope, cond = a < b); + gpu!(scope, if(cond).then(|scope|{ + gpu!(scope, res = a); + }).else(|scope|{ + gpu!(scope, res = b); + })); + + res + } + + fn cubic_interp1d( + scope: &mut Scope, + x0: Variable, + x1: Variable, + x2: Variable, + x3: Variable, + t: Variable, + ) -> Variable { + let item = x0.item(); + let x = scope.create_local(item); + let a: Variable = scope.create_with_value(-0.75, item); + let one: Variable = scope.create_with_value(1, item); + let two: Variable = scope.create_with_value(2, item); + let cubic = scope.create_local(item); + let cubic_tmp = scope.create_local(item); + + gpu!(scope, x = t + one); + let coeffs0 = Self::cubic_convolution2(scope, x, a); + + let coeffs1 = Self::cubic_convolution1(scope, t, a); + + gpu!(scope, x = one - t); + let coeffs2 = Self::cubic_convolution1(scope, x, a); + + gpu!(scope, x = two - t); + let coeffs3 = Self::cubic_convolution2(scope, x, a); + + gpu!(scope, cubic = x0 * coeffs0); + gpu!(scope, cubic_tmp = x1 * coeffs1); + gpu!(scope, cubic += cubic_tmp); + gpu!(scope, cubic_tmp = x2 * coeffs2); + gpu!(scope, cubic += cubic_tmp); + gpu!(scope, cubic_tmp = x3 * coeffs3); + gpu!(scope, cubic += cubic_tmp); + + cubic + } + + fn cubic_convolution1(scope: &mut Scope, x: Variable, a: Variable) -> Variable { + let item = x.item(); + let conv = scope.create_local(item); + let tmp = scope.create_local(item); + let one = scope.create_with_value(1, item); + let two = scope.create_with_value(2, item); + let three = scope.create_with_value(3, item); + + gpu!(scope, conv = a + two); + gpu!(scope, conv *= x); + gpu!(scope, tmp = a + three); + gpu!(scope, conv = conv - tmp); + gpu!(scope, conv *= x); + gpu!(scope, conv *= x); + gpu!(scope, conv += one); + + conv + } + + fn cubic_convolution2(scope: &mut Scope, x: Variable, a: Variable) -> Variable { + let item = x.item(); + let conv = scope.create_local(item); + let tmp = scope.create_local(item); + let four = scope.create_with_value(4, item); + let five = scope.create_with_value(5, item); + let eight = scope.create_with_value(8, item); + + gpu!(scope, conv = a * x); + gpu!(scope, tmp = five * a); + gpu!(scope, conv = conv - tmp); + gpu!(scope, conv *= x); + gpu!(scope, tmp = eight * a); + gpu!(scope, conv += tmp); + gpu!(scope, conv *= x); + gpu!(scope, tmp = four * a); + gpu!(scope, conv = conv - tmp); + + conv + } +} + +impl DynamicKernelSource for InterpolateBicubicEagerKernel { + fn source(&self) -> SourceTemplate { + let mut scope = Scope::root(); + let item = E::gpu_elem().into(); + + let input = Variable::GlobalInputArray(0, item); + let output = Variable::GlobalOutputArray(0, item); + + InterpolateBicubicShader { input, output }.expand(&mut scope); + + scope.write_global_custom(output); + + let input = InputInfo::Array { + item, + visibility: Visibility::Read, + }; + + let out = OutputInfo::Array { item }; + + let info = CompilationInfo { + inputs: vec![input], + outputs: vec![out], + scope, + }; + + let settings = CompilationSettings::default(); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) + } + + fn id(&self) -> String { + format!("{:?}", core::any::TypeId::of::()) + } +} + +pub(crate) fn interpolate_bicubic_launch( + input: JitTensor, + output: JitTensor, +) -> JitTensor { + let kernel = InterpolateBicubicEagerKernel::new(); + + execute_dynamic::, u32>( + &[EagerHandle::new( + &input.handle, + &input.strides, + &input.shape.dims, + )], + &[EagerHandle::new( + &output.handle, + &output.strides, + &output.shape.dims, + )], + None, + kernel, + WorkgroupLaunch::Output { pos: 0 }, + input.client, + ); + + output +} diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs new file mode 100644 index 0000000000..a2387ab5bf --- /dev/null +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -0,0 +1,249 @@ +use std::marker::PhantomData; + +use crate::{ + codegen::{ + execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo, + OutputInfo, WorkgroupLaunch, + }, + gpu::{gpu, Elem, Scope, Variable, Visibility}, + kernel::{DynamicKernelSource, SourceTemplate}, + tensor::JitTensor, + Compiler, JitElement, Runtime, +}; + +#[derive(new)] +struct InterpolateBilinearEagerKernel { + _runtime: PhantomData, + _elem: PhantomData, +} + +struct InterpolateBilinearShader { + input: Variable, + output: Variable, +} + +impl InterpolateBilinearShader { + pub(crate) fn expand(self, scope: &mut Scope) { + let input = self.input; + let output = self.output; + let id = Variable::Id; + + let input_stride_0 = scope.create_local(Elem::UInt); + let input_stride_1 = scope.create_local(Elem::UInt); + let input_stride_2 = scope.create_local(Elem::UInt); + let input_stride_3 = scope.create_local(Elem::UInt); + + let input_shape_2 = scope.create_local(Elem::UInt); + let input_shape_3 = scope.create_local(Elem::UInt); + + let output_stride_0 = scope.create_local(Elem::UInt); + let output_stride_1 = scope.create_local(Elem::UInt); + let output_stride_2 = scope.create_local(Elem::UInt); + let output_stride_3 = scope.create_local(Elem::UInt); + + let output_shape_0 = scope.create_local(Elem::UInt); + let output_shape_1 = scope.create_local(Elem::UInt); + let output_shape_2 = scope.create_local(Elem::UInt); + let output_shape_3 = scope.create_local(Elem::UInt); + + gpu!(scope, input_stride_0 = stride(input, 0u32)); + gpu!(scope, input_stride_1 = stride(input, 1u32)); + gpu!(scope, input_stride_2 = stride(input, 2u32)); + gpu!(scope, input_stride_3 = stride(input, 3u32)); + + gpu!(scope, input_shape_2 = shape(input, 2u32)); + gpu!(scope, input_shape_3 = shape(input, 3u32)); + + gpu!(scope, output_stride_0 = stride(output, 0u32)); + gpu!(scope, output_stride_1 = stride(output, 1u32)); + gpu!(scope, output_stride_2 = stride(output, 2u32)); + gpu!(scope, output_stride_3 = stride(output, 3u32)); + + gpu!(scope, output_shape_0 = shape(output, 0u32)); + gpu!(scope, output_shape_1 = shape(output, 1u32)); + gpu!(scope, output_shape_2 = shape(output, 2u32)); + gpu!(scope, output_shape_3 = shape(output, 3u32)); + + let b = scope.create_local(Elem::UInt); + let c = scope.create_local(Elem::UInt); + let h = scope.create_local(Elem::UInt); + let w = scope.create_local(Elem::UInt); + + gpu!(scope, b = id / output_stride_0); + gpu!(scope, b = b % output_shape_0); + + gpu!(scope, c = id / output_stride_1); + gpu!(scope, c = c % output_shape_1); + + gpu!(scope, h = id / output_stride_2); + gpu!(scope, h = h % output_shape_2); + + gpu!(scope, w = id / output_stride_3); + gpu!(scope, w = w % output_shape_3); + + let factor_float = scope.create_local(input.item()); + let numerator_float = scope.create_local(input.item()); + let numerator_int = scope.create_local(Elem::UInt); + let denominator_float = scope.create_local(input.item()); + let denominator_int = scope.create_local(Elem::UInt); + + let frac = scope.create_local(input.item()); + let v0 = scope.create_local(input.item()); + let v1 = scope.create_local(input.item()); + let one = scope.create_with_value(1f32, input.item()); + + let y0 = scope.create_local(Elem::UInt); + let y1 = scope.create_local(Elem::UInt); + let yw = scope.create_local(input.item()); + let yw_ = scope.create_local(input.item()); + + let x0 = scope.create_local(Elem::UInt); + let x1 = scope.create_local(Elem::UInt); + let xw = scope.create_local(input.item()); + let xw_ = scope.create_local(input.item()); + + gpu!(scope, numerator_int = input_shape_2 - 1u32); + gpu!(scope, denominator_int = output_shape_2 - 1u32); + gpu!(scope, factor_float = cast(h)); + gpu!(scope, numerator_float = cast(numerator_int)); + gpu!(scope, denominator_float = cast(denominator_int)); + gpu!(scope, frac = factor_float * numerator_float); + gpu!(scope, frac = frac / denominator_float); + gpu!(scope, v0 = floor(frac)); + gpu!(scope, v1 = ceil(frac)); + gpu!(scope, yw = frac - v0); + gpu!(scope, yw_ = one - yw); + gpu!(scope, y0 = cast(v0)); + gpu!(scope, y1 = cast(v1)); + + gpu!(scope, numerator_int = input_shape_3 - 1u32); + gpu!(scope, denominator_int = output_shape_3 - 1u32); + gpu!(scope, factor_float = cast(w)); + gpu!(scope, numerator_float = cast(numerator_int)); + gpu!(scope, denominator_float = cast(denominator_int)); + gpu!(scope, frac = factor_float * numerator_float); + gpu!(scope, frac = frac / denominator_float); + gpu!(scope, v0 = floor(frac)); + gpu!(scope, v1 = ceil(frac)); + gpu!(scope, xw = frac - v0); + gpu!(scope, xw_ = one - xw); + gpu!(scope, x0 = cast(v0)); + gpu!(scope, x1 = cast(v1)); + + let index_base = scope.create_local(Elem::UInt); + let index_tmp = scope.create_local(Elem::UInt); + let index = scope.create_local(Elem::UInt); + let y0_stride = scope.create_local(Elem::UInt); + let y1_stride = scope.create_local(Elem::UInt); + let x0_stride = scope.create_local(Elem::UInt); + let x1_stride = scope.create_local(Elem::UInt); + let p_a = scope.create_local(input.item()); + let p_b = scope.create_local(input.item()); + let p_c = scope.create_local(input.item()); + let p_d = scope.create_local(input.item()); + + gpu!(scope, index_base = b * input_stride_0); + gpu!(scope, index_tmp = c * input_stride_1); + gpu!(scope, index_base += index_tmp); + gpu!(scope, y0_stride = y0 * input_stride_2); + gpu!(scope, y1_stride = y1 * input_stride_2); + gpu!(scope, x0_stride = x0 * input_stride_3); + gpu!(scope, x1_stride = x1 * input_stride_3); + + gpu!(scope, index = index_base); + gpu!(scope, index += y0_stride); + gpu!(scope, index += x0_stride); + gpu!(scope, p_a = input[index]); + gpu!(scope, p_a *= xw_); + gpu!(scope, p_a *= yw_); + + gpu!(scope, index = index_base); + gpu!(scope, index += y0_stride); + gpu!(scope, index += x1_stride); + gpu!(scope, p_b = input[index]); + gpu!(scope, p_b *= xw); + gpu!(scope, p_b *= yw_); + + gpu!(scope, index = index_base); + gpu!(scope, index += y1_stride); + gpu!(scope, index += x0_stride); + gpu!(scope, p_c = input[index]); + gpu!(scope, p_c *= xw_); + gpu!(scope, p_c *= yw); + + gpu!(scope, index = index_base); + gpu!(scope, index += y1_stride); + gpu!(scope, index += x1_stride); + gpu!(scope, p_d = input[index]); + gpu!(scope, p_d *= xw); + gpu!(scope, p_d *= yw); + + let sum = scope.create_local(input.item()); + gpu!(scope, sum = p_a + p_b); + gpu!(scope, sum += p_c); + gpu!(scope, sum += p_d); + gpu!(scope, output[id] = sum); + } +} + +impl DynamicKernelSource for InterpolateBilinearEagerKernel { + fn source(&self) -> SourceTemplate { + let mut scope = Scope::root(); + let item = E::gpu_elem().into(); + + let input = Variable::GlobalInputArray(0, item); + let output = Variable::GlobalOutputArray(0, item); + + InterpolateBilinearShader { input, output }.expand(&mut scope); + + scope.write_global_custom(output); + + let input = InputInfo::Array { + item, + visibility: Visibility::Read, + }; + + let out = OutputInfo::Array { item }; + + let info = CompilationInfo { + inputs: vec![input], + outputs: vec![out], + scope, + }; + + let settings = CompilationSettings::default(); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) + } + + fn id(&self) -> String { + format!("{:?}", core::any::TypeId::of::()) + } +} + +pub(crate) fn interpolate_bilinear_launch( + input: JitTensor, + output: JitTensor, +) -> JitTensor { + let kernel = InterpolateBilinearEagerKernel::new(); + + execute_dynamic::, u32>( + &[EagerHandle::new( + &input.handle, + &input.strides, + &input.shape.dims, + )], + &[EagerHandle::new( + &output.handle, + &output.strides, + &output.shape.dims, + )], + None, + kernel, + WorkgroupLaunch::Output { pos: 0 }, + input.client, + ); + + output +} diff --git a/crates/burn-jit/src/kernel/interpolate/mod.rs b/crates/burn-jit/src/kernel/interpolate/mod.rs new file mode 100644 index 0000000000..19e90af285 --- /dev/null +++ b/crates/burn-jit/src/kernel/interpolate/mod.rs @@ -0,0 +1,7 @@ +mod base; +mod bicubic; +mod bilinear; +mod nearest; +mod nearest_backward; + +pub use base::*; diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs new file mode 100644 index 0000000000..6197a0e098 --- /dev/null +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -0,0 +1,185 @@ +use std::marker::PhantomData; + +use crate::{ + codegen::{ + execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo, + OutputInfo, WorkgroupLaunch, + }, + gpu::{gpu, Elem, Scope, Variable, Visibility}, + kernel::{DynamicKernelSource, SourceTemplate}, + tensor::JitTensor, + Compiler, JitElement, Runtime, +}; + +#[derive(new)] +struct InterpolateNearestEagerKernel { + _runtime: PhantomData, + _elem: PhantomData, +} + +struct InterpolateNearestShader { + input: Variable, + output: Variable, +} + +impl InterpolateNearestShader { + pub(crate) fn expand(self, scope: &mut Scope) { + let input = self.input; + let output = self.output; + let id = Variable::Id; + + let input_stride_0 = scope.create_local(Elem::UInt); + let input_stride_1 = scope.create_local(Elem::UInt); + let input_stride_2 = scope.create_local(Elem::UInt); + let input_stride_3 = scope.create_local(Elem::UInt); + + let input_shape_2 = scope.create_local(Elem::UInt); + let input_shape_3 = scope.create_local(Elem::UInt); + + let output_stride_0 = scope.create_local(Elem::UInt); + let output_stride_1 = scope.create_local(Elem::UInt); + let output_stride_2 = scope.create_local(Elem::UInt); + let output_stride_3 = scope.create_local(Elem::UInt); + + let output_shape_0 = scope.create_local(Elem::UInt); + let output_shape_1 = scope.create_local(Elem::UInt); + let output_shape_2 = scope.create_local(Elem::UInt); + let output_shape_3 = scope.create_local(Elem::UInt); + + gpu!(scope, input_stride_0 = stride(input, 0u32)); + gpu!(scope, input_stride_1 = stride(input, 1u32)); + gpu!(scope, input_stride_2 = stride(input, 2u32)); + gpu!(scope, input_stride_3 = stride(input, 3u32)); + + gpu!(scope, input_shape_2 = shape(input, 2u32)); + gpu!(scope, input_shape_3 = shape(input, 3u32)); + + gpu!(scope, output_stride_0 = stride(output, 0u32)); + gpu!(scope, output_stride_1 = stride(output, 1u32)); + gpu!(scope, output_stride_2 = stride(output, 2u32)); + gpu!(scope, output_stride_3 = stride(output, 3u32)); + + gpu!(scope, output_shape_0 = shape(output, 0u32)); + gpu!(scope, output_shape_1 = shape(output, 1u32)); + gpu!(scope, output_shape_2 = shape(output, 2u32)); + gpu!(scope, output_shape_3 = shape(output, 3u32)); + + let b = scope.create_local(Elem::UInt); + let c = scope.create_local(Elem::UInt); + let h = scope.create_local(Elem::UInt); + let w = scope.create_local(Elem::UInt); + + gpu!(scope, b = id / output_stride_0); + gpu!(scope, b = b % output_shape_0); + + gpu!(scope, c = id / output_stride_1); + gpu!(scope, c = c % output_shape_1); + + gpu!(scope, h = id / output_stride_2); + gpu!(scope, h = h % output_shape_2); + + gpu!(scope, w = id / output_stride_3); + gpu!(scope, w = w % output_shape_3); + + let factor_float = scope.create_local(Elem::Float); + let numerator_float = scope.create_local(Elem::Float); + let denominator_float = scope.create_local(Elem::Float); + let x = scope.create_local(Elem::Float); + let y = scope.create_local(Elem::Float); + let xu = scope.create_local(Elem::UInt); + let yu = scope.create_local(Elem::UInt); + + gpu!(scope, factor_float = cast(h)); + gpu!(scope, numerator_float = cast(input_shape_2)); + gpu!(scope, denominator_float = cast(output_shape_2)); + gpu!(scope, y = factor_float * numerator_float); + gpu!(scope, y = y / denominator_float); + gpu!(scope, y = floor(y)); + gpu!(scope, yu = cast(y)); + + gpu!(scope, factor_float = cast(w)); + gpu!(scope, numerator_float = cast(input_shape_3)); + gpu!(scope, denominator_float = cast(output_shape_3)); + gpu!(scope, x = factor_float * numerator_float); + gpu!(scope, x = x / denominator_float); + gpu!(scope, x = floor(x)); + gpu!(scope, xu = cast(x)); + + let index = scope.create_local(Elem::UInt); + let index_tmp = scope.create_local(Elem::UInt); + let val = scope.create_local(output.item()); + + gpu!(scope, index = b * input_stride_0); + gpu!(scope, index_tmp = c * input_stride_1); + gpu!(scope, index += index_tmp); + gpu!(scope, index_tmp = yu * input_stride_2); + gpu!(scope, index += index_tmp); + gpu!(scope, index_tmp = xu * input_stride_3); + gpu!(scope, index += index_tmp); + + gpu!(scope, val = input[index]); + gpu!(scope, output[id] = val); + } +} + +impl DynamicKernelSource for InterpolateNearestEagerKernel { + fn source(&self) -> SourceTemplate { + let mut scope = Scope::root(); + let item = E::gpu_elem().into(); + + let input = Variable::GlobalInputArray(0, item); + let output = Variable::GlobalOutputArray(0, item); + + InterpolateNearestShader { input, output }.expand(&mut scope); + + scope.write_global_custom(output); + + let input = InputInfo::Array { + item, + visibility: Visibility::Read, + }; + + let out = OutputInfo::Array { item }; + + let info = CompilationInfo { + inputs: vec![input], + outputs: vec![out], + scope, + }; + + let settings = CompilationSettings::default(); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) + } + + fn id(&self) -> String { + format!("{:?}", core::any::TypeId::of::()) + } +} + +pub(crate) fn interpolate_nearest_launch( + input: JitTensor, + output: JitTensor, +) -> JitTensor { + let kernel = InterpolateNearestEagerKernel::new(); + + execute_dynamic::, u32>( + &[EagerHandle::new( + &input.handle, + &input.strides, + &input.shape.dims, + )], + &[EagerHandle::new( + &output.handle, + &output.strides, + &output.shape.dims, + )], + None, + kernel, + WorkgroupLaunch::Output { pos: 0 }, + input.client, + ); + + output +} diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs new file mode 100644 index 0000000000..5d6ec08c93 --- /dev/null +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -0,0 +1,244 @@ +use std::marker::PhantomData; + +use crate::{ + codegen::{ + execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo, + OutputInfo, WorkgroupLaunch, + }, + gpu::{gpu, Elem, Scope, Variable, Visibility}, + kernel::{DynamicKernelSource, SourceTemplate}, + tensor::JitTensor, + Compiler, JitElement, Runtime, +}; + +#[derive(new)] +struct InterpolateNearestBackwardEagerKernel { + _runtime: PhantomData, + _elem: PhantomData, +} + +struct InterpolateNearestBackwardShader { + out_grad: Variable, + output: Variable, +} + +impl InterpolateNearestBackwardShader { + fn expand(self, scope: &mut Scope) { + let grad = self.out_grad; + let output = self.output; + let id = Variable::Id; + + let grad_stride_0 = scope.create_local(Elem::UInt); + let grad_stride_1 = scope.create_local(Elem::UInt); + let grad_stride_2 = scope.create_local(Elem::UInt); + let grad_stride_3 = scope.create_local(Elem::UInt); + + let grad_shape_0 = scope.create_local(Elem::UInt); + let grad_shape_1 = scope.create_local(Elem::UInt); + let grad_shape_2 = scope.create_local(Elem::UInt); + let grad_shape_3 = scope.create_local(Elem::UInt); + + let output_stride_0 = scope.create_local(Elem::UInt); + let output_stride_1 = scope.create_local(Elem::UInt); + let output_stride_2 = scope.create_local(Elem::UInt); + let output_stride_3 = scope.create_local(Elem::UInt); + + let output_shape_0 = scope.create_local(Elem::UInt); + let output_shape_1 = scope.create_local(Elem::UInt); + let output_shape_2 = scope.create_local(Elem::UInt); + let output_shape_3 = scope.create_local(Elem::UInt); + + gpu!(scope, grad_stride_0 = stride(grad, 0u32)); + gpu!(scope, grad_stride_1 = stride(grad, 1u32)); + gpu!(scope, grad_stride_2 = stride(grad, 2u32)); + gpu!(scope, grad_stride_3 = stride(grad, 3u32)); + + gpu!(scope, grad_shape_0 = shape(grad, 0u32)); + gpu!(scope, grad_shape_1 = shape(grad, 1u32)); + gpu!(scope, grad_shape_2 = shape(grad, 2u32)); + gpu!(scope, grad_shape_3 = shape(grad, 3u32)); + + gpu!(scope, output_stride_0 = stride(output, 0u32)); + gpu!(scope, output_stride_1 = stride(output, 1u32)); + gpu!(scope, output_stride_2 = stride(output, 2u32)); + gpu!(scope, output_stride_3 = stride(output, 3u32)); + + gpu!(scope, output_shape_0 = shape(output, 0u32)); + gpu!(scope, output_shape_1 = shape(output, 1u32)); + gpu!(scope, output_shape_2 = shape(output, 2u32)); + gpu!(scope, output_shape_3 = shape(output, 3u32)); + + let b = scope.create_local(Elem::UInt); + let c = scope.create_local(Elem::UInt); + let oh = scope.create_local(Elem::UInt); + let ow = scope.create_local(Elem::UInt); + + gpu!(scope, b = id / output_stride_0); + gpu!(scope, b = b % output_shape_0); + + gpu!(scope, c = id / output_stride_1); + gpu!(scope, c = c % output_shape_1); + + gpu!(scope, oh = id / output_stride_2); + gpu!(scope, oh = oh % output_shape_2); + + gpu!(scope, ow = id / output_stride_3); + gpu!(scope, ow = ow % output_shape_3); + + let gh_start = Self::start_index(scope, oh, grad_shape_2, output_shape_2); + let gh_end = Self::end_index(scope, oh, grad_shape_2, output_shape_2); + let gw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3); + let gw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3); + + let result = scope.create_local(grad.item()); + + let index_grad = scope.create_local(Elem::UInt); + let index_grad_0 = scope.create_local(Elem::UInt); + let index_grad_1 = scope.create_local(Elem::UInt); + let index_grad_2 = scope.create_local(Elem::UInt); + let index_grad_3 = scope.create_local(Elem::UInt); + + gpu!(scope, index_grad_0 = b * grad_stride_0); + gpu!(scope, index_grad_1 = c * grad_stride_1); + + let sum = scope.zero(output.item()); + + gpu!( + scope, + range(gh_start, gh_end).for_each(|gh, scope| { + gpu!( + scope, + range(gw_start, gw_end).for_each(|gw, scope| { + gpu!(scope, index_grad_2 = gh * grad_stride_2); + gpu!(scope, index_grad_3 = gw * grad_stride_3); + + gpu!(scope, index_grad = index_grad_0); + gpu!(scope, index_grad += index_grad_1); + gpu!(scope, index_grad += index_grad_2); + gpu!(scope, index_grad += index_grad_3); + + gpu!(scope, result = grad[index_grad]); + + gpu!(scope, sum += result); + }) + ); + }) + ); + + gpu!(scope, output[id] = sum); + } + + fn start_index( + scope: &mut Scope, + input_index: Variable, + output_size: Variable, + input_size: Variable, + ) -> Variable { + let numerator_float = scope.create_local(Elem::Float); + let div = scope.create_local(Elem::Float); + let index = scope.create_local(Elem::UInt); + + gpu!(scope, index = input_index * output_size); + gpu!(scope, numerator_float = cast(index)); + gpu!(scope, div = cast(input_size)); + gpu!(scope, div = numerator_float / div); + gpu!(scope, div = ceil(div)); + gpu!(scope, index = cast(div)); + + index + } + + fn end_index( + scope: &mut Scope, + input_index: Variable, + output_size: Variable, + input_size: Variable, + ) -> Variable { + let numerator_float = scope.create_local(Elem::Float); + let div = scope.create_local(Elem::Float); + let index = scope.create_local(Elem::UInt); + let min = scope.create_local(Elem::Bool); + let end_index = scope.create_local(Elem::UInt); + + gpu!(scope, index = input_index + 1u32); + gpu!(scope, index *= output_size); + gpu!(scope, numerator_float = cast(index)); + gpu!(scope, div = cast(input_size)); + gpu!(scope, div = numerator_float / div); + gpu!(scope, div = ceil(div)); + gpu!(scope, index = cast(div)); + + gpu!(scope, min = output_size < index); + gpu!(scope, if(min).then(|scope|{ + gpu!(scope, end_index = output_size); + }).else(|scope|{ + gpu!(scope, end_index = index); + })); + + end_index + } +} + +impl DynamicKernelSource + for InterpolateNearestBackwardEagerKernel +{ + fn source(&self) -> SourceTemplate { + let mut scope = Scope::root(); + let item = E::gpu_elem().into(); + + let out_grad = Variable::GlobalInputArray(0, item); + let output = Variable::GlobalOutputArray(0, item); + + InterpolateNearestBackwardShader { out_grad, output }.expand(&mut scope); + + scope.write_global_custom(output); + + let input = InputInfo::Array { + item, + visibility: Visibility::Read, + }; + + let out = OutputInfo::Array { item }; + + let info = CompilationInfo { + inputs: vec![input], + outputs: vec![out], + scope, + }; + + let settings = CompilationSettings::default(); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) + } + + fn id(&self) -> String { + format!("{:?}", core::any::TypeId::of::()) + } +} + +pub(crate) fn interpolate_nearest_backward_launch( + out_grad: JitTensor, + output: JitTensor, +) -> JitTensor { + let kernel = InterpolateNearestBackwardEagerKernel::new(); + + execute_dynamic::, u32>( + &[EagerHandle::new( + &out_grad.handle, + &out_grad.strides, + &out_grad.shape.dims, + )], + &[EagerHandle::new( + &output.handle, + &output.strides, + &output.shape.dims, + )], + None, + kernel, + WorkgroupLaunch::Output { pos: 0 }, + out_grad.client, + ); + + output +} diff --git a/crates/burn-jit/src/template/interpolate/bicubic.wgsl b/crates/burn-jit/src/template/interpolate/bicubic.wgsl deleted file mode 100644 index 81604c2f54..0000000000 --- a/crates/burn-jit/src/template/interpolate/bicubic.wgsl +++ /dev/null @@ -1,113 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var output: array<{{ elem }}>; - -@group(0) -@binding(2) -var info: array; - -const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; - -fn cubic_convolution1(x: f32, a: f32) -> f32 { - return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0; -} - -fn cubic_convolution2(x: f32, a: f32) -> f32 { - return ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a; -} - -fn cubic_interp1d(x0: f32, x1: f32, x2: f32, x3: f32, t: f32) -> f32 { - let coeffs0 = cubic_convolution2(t + 1.0, -0.75); - let coeffs1 = cubic_convolution1(t, -0.75); - let coeffs2 = cubic_convolution1(1.0 - t, -0.75); - let coeffs3 = cubic_convolution2(2.0 - t, -0.75); - return x0 * coeffs0 + x1 * coeffs1 + x2 * coeffs2 + x3 * coeffs3; -} - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; - - let input_stride_0 = info[1]; - let input_stride_1 = info[2]; - let input_stride_2 = info[3]; - let input_stride_3 = info[4]; - let output_stride_0 = info[5]; - let output_stride_1 = info[6]; - let output_stride_2 = info[7]; - let output_stride_3 = info[8]; - - let input_shape_0 = info[9]; - let input_shape_1 = info[10]; - let input_shape_2 = info[11]; - let input_shape_3 = info[12]; - let output_shape_0 = info[13]; - let output_shape_1 = info[14]; - let output_shape_2 = info[15]; - let output_shape_3 = info[16]; - - let b = id / output_stride_0 % output_shape_0; - let c = id / output_stride_1 % output_shape_1; - let h = id / output_stride_2 % output_shape_2; - let w = id / output_stride_3 % output_shape_3; - - let input_height = f32(input_shape_2 - 1u); - let y_frac = f32(h) * input_height / f32(output_shape_2 - 1u); - let y_in = floor(y_frac); - let yw = y_frac - y_in; - - let y0 = u32(max(y_in - 1.0, 0.0)); - let y1 = u32(y_in); - let y2 = u32(min(y_in + 1.0, input_height)); - let y3 = u32(min(y_in + 2.0, input_height)); - - let input_width = f32(input_shape_3 - 1u); - let x_frac = f32(w) * input_width / f32(output_shape_3 - 1u); - let x_in = floor(x_frac); - let xw = x_frac - x_in; - - let x0 = u32(max(x_in - 1.0, 0.0)); - let x1 = u32(x_in); - let x2 = u32(min(x_in + 1.0, input_width)); - let x3 = u32(min(x_in + 2.0, input_width)); - - let coefficients0 = cubic_interp1d( - input[b * input_stride_0 + c * input_stride_1 + y0 * input_stride_2 + x0 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y0 * input_stride_2 + x1 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y0 * input_stride_2 + x2 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y0 * input_stride_2 + x3 * input_stride_3], - xw, - ); - let coefficients1 = cubic_interp1d( - input[b * input_stride_0 + c * input_stride_1 + y1 * input_stride_2 + x0 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y1 * input_stride_2 + x1 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y1 * input_stride_2 + x2 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y1 * input_stride_2 + x3 * input_stride_3], - xw, - ); - let coefficients2 = cubic_interp1d( - input[b * input_stride_0 + c * input_stride_1 + y2 * input_stride_2 + x0 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y2 * input_stride_2 + x1 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y2 * input_stride_2 + x2 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y2 * input_stride_2 + x3 * input_stride_3], - xw, - ); - let coefficients3 = cubic_interp1d( - input[b * input_stride_0 + c * input_stride_1 + y3 * input_stride_2 + x0 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y3 * input_stride_2 + x1 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y3 * input_stride_2 + x2 * input_stride_3], - input[b * input_stride_0 + c * input_stride_1 + y3 * input_stride_2 + x3 * input_stride_3], - xw, - ); - - let val = cubic_interp1d(coefficients0, coefficients1, coefficients2, coefficients3, yw); - output[id] = val; -} diff --git a/crates/burn-jit/src/template/interpolate/bilinear.wgsl b/crates/burn-jit/src/template/interpolate/bilinear.wgsl deleted file mode 100644 index a8f4441725..0000000000 --- a/crates/burn-jit/src/template/interpolate/bilinear.wgsl +++ /dev/null @@ -1,72 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var output: array<{{ elem }}>; - -@group(0) -@binding(2) -var info: array; - -const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; - - let input_stride_0 = info[1]; - let input_stride_1 = info[2]; - let input_stride_2 = info[3]; - let input_stride_3 = info[4]; - let output_stride_0 = info[5]; - let output_stride_1 = info[6]; - let output_stride_2 = info[7]; - let output_stride_3 = info[8]; - - let input_shape_0 = info[9]; - let input_shape_1 = info[10]; - let input_shape_2 = info[11]; - let input_shape_3 = info[12]; - let output_shape_0 = info[13]; - let output_shape_1 = info[14]; - let output_shape_2 = info[15]; - let output_shape_3 = info[16]; - - let b = id / output_stride_0 % output_shape_0; - let c = id / output_stride_1 % output_shape_1; - let h = id / output_stride_2 % output_shape_2; - let w = id / output_stride_3 % output_shape_3; - - let y_frac = f32(h) * f32(input_shape_2 - 1u) / f32(output_shape_2 - 1u); - let y0 = floor(y_frac); - let y1 = ceil(y_frac); - let yw = y_frac - y0; - - let x_frac = f32(w) * f32(input_shape_3 - 1u) / f32(output_shape_3 - 1u); - let x0 = floor(x_frac); - let x1 = ceil(x_frac); - let xw = x_frac - x0; - - let x0u = u32(x0); - let x1u = u32(x1); - let y0u = u32(y0); - let y1u = u32(y1); - - let p_a = input[b * input_stride_0 + c * input_stride_1 + y0u * input_stride_2 + x0u * input_stride_3]; - let p_b = input[b * input_stride_0 + c * input_stride_1 + y0u * input_stride_2 + x1u * input_stride_3]; - let p_c = input[b * input_stride_0 + c * input_stride_1 + y1u * input_stride_2 + x0u * input_stride_3]; - let p_d = input[b * input_stride_0 + c * input_stride_1 + y1u * input_stride_2 + x1u * input_stride_3]; - - let pa = p_a * (1.0 - xw) * (1.0 - yw); - let pb = p_b * xw * (1.0 - yw); - let pc = p_c * (1.0 - xw) * yw; - let pd = p_d * xw * yw; - - output[id] = pa + pb + pc + pd; -} diff --git a/crates/burn-jit/src/template/interpolate/nearest.wgsl b/crates/burn-jit/src/template/interpolate/nearest.wgsl deleted file mode 100644 index a6f26572f1..0000000000 --- a/crates/burn-jit/src/template/interpolate/nearest.wgsl +++ /dev/null @@ -1,54 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var output: array<{{ elem }}>; - -@group(0) -@binding(2) -var info: array; - -const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; - - let input_stride_0 = info[1]; - let input_stride_1 = info[2]; - let input_stride_2 = info[3]; - let input_stride_3 = info[4]; - let output_stride_0 = info[5]; - let output_stride_1 = info[6]; - let output_stride_2 = info[7]; - let output_stride_3 = info[8]; - - let input_shape_0 = info[9]; - let input_shape_1 = info[10]; - let input_shape_2 = info[11]; - let input_shape_3 = info[12]; - let output_shape_0 = info[13]; - let output_shape_1 = info[14]; - let output_shape_2 = info[15]; - let output_shape_3 = info[16]; - - let b = id / output_stride_0 % output_shape_0; - let c = id / output_stride_1 % output_shape_1; - let h = id / output_stride_2 % output_shape_2; - let w = id / output_stride_3 % output_shape_3; - - let y = f32(h) * f32(input_shape_2) / f32(output_shape_2); - let x = f32(w) * f32(input_shape_3) / f32(output_shape_3); - - let xu = u32(floor(x)); - let yu = u32(floor(y)); - - let val = input[b * input_stride_0 + c * input_stride_1 + yu * input_stride_2 + xu * input_stride_3]; - output[id] = val; -} diff --git a/crates/burn-jit/src/template/interpolate/nearest_backward.wgsl b/crates/burn-jit/src/template/interpolate/nearest_backward.wgsl deleted file mode 100644 index 8a3ae6b239..0000000000 --- a/crates/burn-jit/src/template/interpolate/nearest_backward.wgsl +++ /dev/null @@ -1,71 +0,0 @@ -@group(0) -@binding(0) -var grad: array<{{ elem }}>; - -@group(0) -@binding(1) -var output: array<{{ elem }}>; - -@group(0) -@binding(2) -var info: array; - -const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; - - let output_stride_0 = info[1]; - let output_stride_1 = info[2]; - let output_stride_2 = info[3]; - let output_stride_3 = info[4]; - let grad_stride_0 = info[5]; - let grad_stride_1 = info[6]; - let grad_stride_2 = info[7]; - let grad_stride_3 = info[8]; - - let output_shape_0 = info[9]; - let output_shape_1 = info[10]; - let output_shape_2 = info[11]; - let output_shape_3 = info[12]; - let grad_shape_0 = info[13]; - let grad_shape_1 = info[14]; - let grad_shape_2 = info[15]; - let grad_shape_3 = info[16]; - - let b = (id / output_stride_0) % output_shape_0; - let c = (id / output_stride_1) % output_shape_1; - let oh = (id / output_stride_2) % output_shape_2; - let ow = (id / output_stride_3) % output_shape_3; - - let gh_start = start_index(oh, grad_shape_2, output_shape_2); - let gh_end = end_index(oh, grad_shape_2, output_shape_2); - - let gw_start = start_index(ow, grad_shape_3, output_shape_3); - let gw_end = end_index(ow, grad_shape_3, output_shape_3); - - var grad_acc = 0.0; - - for (var gh = gh_start; gh < gh_end; gh++) { - for (var gw = gw_start; gw < gw_end; gw++) { - let index = b * grad_stride_0 + c * grad_stride_1 + gh * grad_stride_2 + gw * grad_stride_3; - grad_acc += grad[index]; - } - } - - output[id] = grad_acc; -} - -fn start_index(input_index: u32, output_size: u32, input_size: u32) -> u32 { - return u32(ceil(f32(input_index) * (f32(output_size) / f32(input_size)))); -} - -fn end_index(input_index: u32, output_size: u32, input_size: u32) -> u32 { - let index = u32(ceil(f32(input_index + 1u) * (f32(output_size) / f32(input_size)))); - return min(index, output_size); -}