From f9b9e98f1539b819c098d57a97d34615c0192834 Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Thu, 6 Mar 2025 13:07:18 -0500
Subject: [PATCH] Fix fallback

---
 .../src/reduce/optimization.rs               |  12 +-
 crates/burn-cubecl-fusion/src/reduce/tune.rs |   6 +-
 crates/burn-cubecl/src/fusion.rs             | 103 +++++++++++++-----
 3 files changed, 88 insertions(+), 33 deletions(-)

diff --git a/crates/burn-cubecl-fusion/src/reduce/optimization.rs b/crates/burn-cubecl-fusion/src/reduce/optimization.rs
index ba22eb0977..1820431708 100644
--- a/crates/burn-cubecl-fusion/src/reduce/optimization.rs
+++ b/crates/burn-cubecl-fusion/src/reduce/optimization.rs
@@ -52,6 +52,8 @@ pub trait ReduceFallbackFn: Send + Sync {
         input_handle: CubeFusionHandle<R>,
         shape: &[usize],
         axis: usize,
+        inst: &ReduceInstruction,
+        dtype_out: &DType,
     ) -> CubeFusionHandle<R>;
 }
 
@@ -217,9 +219,13 @@ impl ReduceOptimization {
             let input_handle = context
                 .handles
                 .get_handle(&input.id, &TensorStatus::ReadOnly);
-            let out_handle = self
-                .fallback
-                .run(input_handle, &input.shape, self.reduce.op.axis);
+            let out_handle = self.fallback.run(
+                input_handle,
+                &input.shape,
+                self.reduce.op.axis,
+                &self.reduce.inst,
+                &self.reduce.op.out.dtype,
+            );
             (out_handle, out)
         };
 
diff --git a/crates/burn-cubecl-fusion/src/reduce/tune.rs b/crates/burn-cubecl-fusion/src/reduce/tune.rs
index 29932c506c..5b5587f86e 100644
--- a/crates/burn-cubecl-fusion/src/reduce/tune.rs
+++ b/crates/burn-cubecl-fusion/src/reduce/tune.rs
@@ -94,9 +94,11 @@ fn tune_reduce_shared_plane(
     let context = input.context();
 
     match context {
-        TuneContext::Original(context) => optimization.execute_fused_reduce::<BT>(context),
+        TuneContext::Original(context) => {
+            optimization.execute_fused_reduce_shared_plane::<BT>(context)
+        }
         TuneContext::Fork(mut context_owned) => {
-            optimization.execute_fused_reduce::<BT>(&mut context_owned.as_context())
+            optimization.execute_fused_reduce_shared_plane::<BT>(&mut context_owned.as_context())
         }
     }
     .map_err(|e| format!("{e:?}"))
diff --git a/crates/burn-cubecl/src/fusion.rs b/crates/burn-cubecl/src/fusion.rs
index 66eee1a03d..f179bc411e 100644
--- a/crates/burn-cubecl/src/fusion.rs
+++ b/crates/burn-cubecl/src/fusion.rs
@@ -1,3 +1,4 @@
+use crate::element::CubeElement;
 use crate::BoolElement;
 use crate::{kernel, tensor::CubeTensor, CubeBackend, CubeRuntime, FloatElement, IntElement};
 
@@ -6,15 +7,19 @@ use burn_cubecl_fusion::matmul::builder::MatmulBuilder;
 use burn_cubecl_fusion::matmul::optimization::MatmulOptimization;
 use burn_cubecl_fusion::matmul::MatmulFallbackFn;
 use burn_cubecl_fusion::reduce::builder::ReduceBuilder;
-use burn_cubecl_fusion::reduce::optimization::{ReduceFallbackFn, ReduceOptimization};
+use burn_cubecl_fusion::reduce::optimization::{
+    ReduceFallbackFn, ReduceInstruction, ReduceOptimization,
+};
 use burn_cubecl_fusion::CubeFusionHandle;
 use burn_cubecl_fusion::{
     elemwise::builder::ElementWiseBuilder, CubeOptimization, CubeOptimizationState,
 };
 use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
 use burn_ir::{BackendIr, TensorHandle};
-use burn_tensor::Shape;
+use burn_tensor::{DType, Shape};
 use core::marker::PhantomData;
+use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum};
+use cubecl::reduce::Reduce;
 use half::{bf16, f16};
 use std::sync::Arc;
 
@@ -76,10 +81,10 @@ impl MatmulFallbackFn for FallbackMatmul {
         rhs: (CubeFusionHandle<R>, &[usize]),
     ) -> CubeFusionHandle<R> {
         match lhs.0.dtype {
-            burn_tensor::DType::F64 => run_fallback_matmul::<R, f64>(lhs, rhs),
-            burn_tensor::DType::F32 => run_fallback_matmul::<R, f32>(lhs, rhs),
-            burn_tensor::DType::F16 => run_fallback_matmul::<R, f16>(lhs, rhs),
-            burn_tensor::DType::BF16 => run_fallback_matmul::<R, bf16>(lhs, rhs),
+            DType::F64 => run_fallback_matmul::<R, f64>(lhs, rhs),
+            DType::F32 => run_fallback_matmul::<R, f32>(lhs, rhs),
+            DType::F16 => run_fallback_matmul::<R, f16>(lhs, rhs),
+            DType::BF16 => run_fallback_matmul::<R, bf16>(lhs, rhs),
             _ => todo!("Not yet supported"),
         }
     }
@@ -88,16 +93,19 @@ impl MatmulFallbackFn for FallbackMatmul {
 impl<R: CubeRuntime> ReduceFallbackFn<R> for FallbackReduce {
     fn run(
         &self,
-        input_handle: CubeFusionHandle<R>,
+        input: CubeFusionHandle<R>,
         shape: &[usize],
         axis: usize,
+        inst: &ReduceInstruction,
+        d_o: &DType,
     ) -> CubeFusionHandle<R> {
-        match input_handle.dtype {
-            burn_tensor::DType::F64 => run_fallback_reduce::<R, f64>(input_handle, shape, axis),
-            burn_tensor::DType::F32 => run_fallback_reduce::<R, f32>(input_handle, shape, axis),
-            burn_tensor::DType::F16 => run_fallback_reduce::<R, f16>(input_handle, shape, axis),
-            burn_tensor::DType::BF16 => run_fallback_reduce::<R, bf16>(input_handle, shape, axis),
-            _ => todo!("Not yet supported"),
+        let d_i = input.dtype;
+        match inst {
+            ReduceInstruction::ArgMax => reduce_dtype::<R, ArgMax>(input, shape, axis, &d_i, d_o),
+            ReduceInstruction::ArgMin => reduce_dtype::<R, ArgMin>(input, shape, axis, &d_i, d_o),
+            ReduceInstruction::Mean => reduce_dtype::<R, Mean>(input, shape, axis, &d_i, d_o),
+            ReduceInstruction::Prod => reduce_dtype::<R, Prod>(input, shape, axis, &d_i, d_o),
+            ReduceInstruction::Sum => reduce_dtype::<R, Sum>(input, shape, axis, &d_i, d_o),
         }
     }
 }
@@ -135,7 +143,50 @@ fn run_fallback_matmul(
     }
 }
 
-fn run_fallback_reduce<R: CubeRuntime, EG: FloatElement>(
+fn reduce_dtype<R: CubeRuntime, Red: Reduce>(
+    input_handle: CubeFusionHandle<R>,
+    shape: &[usize],
+    axis: usize,
+    dtype_input: &DType,
+    dtype_output: &DType,
+) -> CubeFusionHandle<R> {
+    match dtype_input {
+        DType::F64 => reduce_dtype_output::<R, f64, Red>(input_handle, shape, axis, dtype_output),
+        DType::F32 => reduce_dtype_output::<R, f32, Red>(input_handle, shape, axis, dtype_output),
+        DType::F16 => reduce_dtype_output::<R, f16, Red>(input_handle, shape, axis, dtype_output),
+        DType::BF16 => reduce_dtype_output::<R, bf16, Red>(input_handle, shape, axis, dtype_output),
+        DType::I64 => reduce_dtype_output::<R, i64, Red>(input_handle, shape, axis, dtype_output),
+        DType::I32 => reduce_dtype_output::<R, i32, Red>(input_handle, shape, axis, dtype_output),
+        DType::I16 => reduce_dtype_output::<R, i16, Red>(input_handle, shape, axis, dtype_output),
+        DType::U64 => reduce_dtype_output::<R, u64, Red>(input_handle, shape, axis, dtype_output),
+        DType::U32 => reduce_dtype_output::<R, u32, Red>(input_handle, shape, axis, dtype_output),
+        DType::U16 => reduce_dtype_output::<R, u16, Red>(input_handle, shape, axis, dtype_output),
+        _ => todo!("Not yet supported"),
+    }
+}
+
+fn reduce_dtype_output<R: CubeRuntime, In: CubeElement, Red: Reduce>(
+    input_handle: CubeFusionHandle<R>,
+    shape: &[usize],
+    axis: usize,
+    dtype_output: &DType,
+) -> CubeFusionHandle<R> {
+    match dtype_output {
+        DType::F64 => reduce::<R, In, f64, Red>(input_handle, shape, axis),
+        DType::F32 => reduce::<R, In, f32, Red>(input_handle, shape, axis),
+        DType::F16 => reduce::<R, In, f16, Red>(input_handle, shape, axis),
+        DType::BF16 => reduce::<R, In, bf16, Red>(input_handle, shape, axis),
+        DType::I64 => reduce::<R, In, i64, Red>(input_handle, shape, axis),
+        DType::I32 => reduce::<R, In, i32, Red>(input_handle, shape, axis),
+        DType::I16 => reduce::<R, In, i16, Red>(input_handle, shape, axis),
+        DType::U64 => reduce::<R, In, u64, Red>(input_handle, shape, axis),
+        DType::U32 => reduce::<R, In, u32, Red>(input_handle, shape, axis),
+        DType::U16 => reduce::<R, In, u16, Red>(input_handle, shape, axis),
+        _ => todo!("Not yet supported"),
+    }
+}
+
+fn reduce<R: CubeRuntime, In: CubeElement, Out: CubeElement, Red: Reduce>(
     input_handle: CubeFusionHandle<R>,
     shape: &[usize],
     axis: usize,
@@ -146,13 +197,12 @@ fn run_fallback_reduce(
 ) -> CubeFusionHandle<R> {
     let input_tensor = into_tensor(
         input_handle,
         Shape {
             dims: shape.to_vec(),
         },
     );
-    let out_tensor =
-        crate::kernel::reduce::reduce_dim::<R, EG, EG, cubecl::reduce::instructions::Sum>(
-            input_tensor,
-            axis,
-            crate::kernel::reduce::ReduceStrategy::default(),
-        )
-        .unwrap();
+    let out_tensor = crate::kernel::reduce::reduce_dim::<R, In, Out, Red>(
+        input_tensor,
+        axis,
+        crate::kernel::reduce::ReduceStrategy::default(),
+    )
+    .unwrap();
 
     CubeFusionHandle {
         client: out_tensor.client,
@@ -247,10 +297,7 @@ impl FusionBack
     type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
 
-    fn cast_float(
-        tensor: burn_tensor::ops::FloatTensor<Self>,
-        dtype: burn_tensor::DType,
-    ) -> Self::Handle {
+    fn cast_float(tensor: burn_tensor::ops::FloatTensor<Self>, dtype: DType) -> Self::Handle {
         fn cast<R: CubeRuntime, F: FloatElement, FTarget: FloatElement>(
             tensor: CubeTensor<R>,
         ) -> CubeFusionHandle<R> {
@@ -258,9 +305,9 @@ impl FusionBack
         }
 
         match dtype {
-            burn_tensor::DType::F32 => cast::<R, F, f32>(tensor),
-            burn_tensor::DType::F16 => cast::<R, F, f16>(tensor),
-            burn_tensor::DType::BF16 => cast::<R, F, bf16>(tensor),
+            DType::F32 => cast::<R, F, f32>(tensor),
+            DType::F16 => cast::<R, F, f16>(tensor),
+            DType::BF16 => cast::<R, F, bf16>(tensor),
             _ => panic!("Casting error: {dtype:?} unsupported."),
         }
     }
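
Note on the dispatch pattern introduced above: the reduce fallback can no longer assume the output tensor shares the input's dtype (ArgMax/ArgMin produce index outputs), so it now receives the ReduceInstruction and the output DType and resolves the input element type, the output element type, and the instruction before calling the typed kernel. The following is a minimal, self-contained sketch of that two-level dtype dispatch. It is an illustration only: the enum, the Element trait, and the function bodies are simplified stand-ins, not the real burn/cubecl APIs, and only the shape of the dispatch mirrors reduce_dtype / reduce_dtype_output in the patch.

// Sketch of two-level runtime-dtype dispatch into a monomorphized kernel.
// All types here are hypothetical stand-ins for burn's DType and element traits.

#[derive(Clone, Copy, Debug)]
enum DType {
    F64,
    F32,
    I64,
}

trait Element {
    const NAME: &'static str;
}
impl Element for f64 { const NAME: &'static str = "f64"; }
impl Element for f32 { const NAME: &'static str = "f32"; }
impl Element for i64 { const NAME: &'static str = "i64"; }

/// First level: pick the *input* element type from the runtime input dtype.
fn reduce_dtype(dtype_input: DType, dtype_output: DType) -> String {
    match dtype_input {
        DType::F64 => reduce_dtype_output::<f64>(dtype_output),
        DType::F32 => reduce_dtype_output::<f32>(dtype_output),
        DType::I64 => reduce_dtype_output::<i64>(dtype_output),
    }
}

/// Second level: pick the *output* element type, then call the fully typed kernel.
fn reduce_dtype_output<In: Element>(dtype_output: DType) -> String {
    match dtype_output {
        DType::F64 => reduce::<In, f64>(),
        DType::F32 => reduce::<In, f32>(),
        DType::I64 => reduce::<In, i64>(),
    }
}

/// Stand-in for the typed reduce kernel, monomorphized over both element types.
fn reduce<In: Element, Out: Element>() -> String {
    format!("reduce {} -> {}", In::NAME, Out::NAME)
}

fn main() {
    // An arg-max over an f32 tensor writes i64 indices, so the input and output
    // element types differ; the pre-fix fallback could not express this.
    assert_eq!(reduce_dtype(DType::F32, DType::I64), "reduce f32 -> i64");
}

Dispatching the output dtype separately from the input dtype is what lets one fallback path serve ArgMax/ArgMin (index-typed output) as well as Sum, Mean, and Prod (value-typed output) without assuming the output matches the input.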