
Commit

Fix fallback
nathanielsimard committed Mar 6, 2025
1 parent c736943 commit f9b9e98
Showing 3 changed files with 88 additions and 33 deletions.
12 changes: 9 additions & 3 deletions crates/burn-cubecl-fusion/src/reduce/optimization.rs
@@ -52,6 +52,8 @@ pub trait ReduceFallbackFn<R: Runtime>: Send + Sync {
input_handle: CubeFusionHandle<R>,
shape: &[usize],
axis: usize,
inst: &ReduceInstruction,
dtype_out: &DType,
) -> CubeFusionHandle<R>;
}

@@ -217,9 +219,13 @@ impl<R: Runtime> ReduceOptimization<R> {
let input_handle = context
.handles
.get_handle(&input.id, &TensorStatus::ReadOnly);
let out_handle = self
.fallback
.run(input_handle, &input.shape, self.reduce.op.axis);
let out_handle = self.fallback.run(
input_handle,
&input.shape,
self.reduce.op.axis,
&self.reduce.inst,
&self.reduce.op.out.dtype,
);

(out_handle, out)
};
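For context, the trait change above widens the fallback signature so the fused-reduce path can tell the fallback which reduce instruction to run and which output dtype to produce. Below is a minimal, self-contained sketch of that shape; `Handle`, `ReduceInstruction`, and `DType` here are simplified stand-ins for the real burn-cubecl-fusion types, not the actual API.

```rust
// Illustrative stand-ins for the real burn-cubecl-fusion types.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
enum ReduceInstruction { ArgMax, ArgMin, Mean, Prod, Sum }

#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
enum DType { F32, F16 }

struct Handle { dtype: DType }

trait ReduceFallbackFn: Send + Sync {
    // The instruction and the output dtype now travel with the call, so the
    // fallback can pick both the reduction kernel and the output element type.
    fn run(
        &self,
        input: Handle,
        shape: &[usize],
        axis: usize,
        inst: &ReduceInstruction,
        dtype_out: &DType,
    ) -> Handle;
}

struct FallbackReduce;

impl ReduceFallbackFn for FallbackReduce {
    fn run(
        &self,
        _input: Handle,
        _shape: &[usize],
        _axis: usize,
        _inst: &ReduceInstruction,
        dtype_out: &DType,
    ) -> Handle {
        // A real implementation dispatches on (input dtype, inst, dtype_out);
        // this stub only echoes the requested output dtype.
        Handle { dtype: *dtype_out }
    }
}

fn main() {
    let out = FallbackReduce.run(
        Handle { dtype: DType::F32 },
        &[2, 3],
        1,
        &ReduceInstruction::Sum,
        &DType::F32,
    );
    println!("{:?}", out.dtype);
}
```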
6 changes: 4 additions & 2 deletions crates/burn-cubecl-fusion/src/reduce/tune.rs
@@ -94,9 +94,11 @@ fn tune_reduce_shared_plane<R: Runtime, BT: CubeElement>(
let context = input.context();

match context {
TuneContext::Original(context) => optimization.execute_fused_reduce::<BT>(context),
TuneContext::Original(context) => {
optimization.execute_fused_reduce_shared_plane::<BT>(context)
}
TuneContext::Fork(mut context_owned) => {
optimization.execute_fused_reduce::<BT>(&mut context_owned.as_context())
optimization.execute_fused_reduce_shared_plane::<BT>(&mut context_owned.as_context())
}
}
.map_err(|e| format!("{e:?}"))
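The tune.rs change is a routing fix: both arms of the shared-plane autotune candidate previously called `execute_fused_reduce`, so the shared-plane variant does not appear to have been the path actually exercised during tuning. A compact, self-contained sketch of the pattern the fix restores, with placeholder types standing in for the real burn-fusion tuning API:

```rust
struct Context;

#[allow(dead_code)]
enum TuneContext<'a> {
    Original(&'a mut Context),
    Fork(Box<Context>),
}

struct Optimization;

impl Optimization {
    fn execute_fused_reduce_shared_plane(&self, _ctx: &mut Context) -> Result<(), String> {
        Ok(())
    }
}

// Both the borrowed (Original) and owned (Fork) contexts must be routed to the
// execution variant this candidate is actually benchmarking.
fn tune_reduce_shared_plane(opt: &Optimization, input: TuneContext) -> Result<(), String> {
    match input {
        TuneContext::Original(context) => opt.execute_fused_reduce_shared_plane(context),
        TuneContext::Fork(mut context_owned) => {
            opt.execute_fused_reduce_shared_plane(&mut context_owned)
        }
    }
    .map_err(|e| format!("{e:?}"))
}

fn main() {
    let mut ctx = Context;
    tune_reduce_shared_plane(&Optimization, TuneContext::Original(&mut ctx)).unwrap();
}
```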
103 changes: 75 additions & 28 deletions crates/burn-cubecl/src/fusion.rs
@@ -1,3 +1,4 @@
use crate::element::CubeElement;
use crate::BoolElement;
use crate::{kernel, tensor::CubeTensor, CubeBackend, CubeRuntime, FloatElement, IntElement};

@@ -6,15 +7,19 @@ use burn_cubecl_fusion::matmul::builder::MatmulBuilder;
use burn_cubecl_fusion::matmul::optimization::MatmulOptimization;
use burn_cubecl_fusion::matmul::MatmulFallbackFn;
use burn_cubecl_fusion::reduce::builder::ReduceBuilder;
use burn_cubecl_fusion::reduce::optimization::{ReduceFallbackFn, ReduceOptimization};
use burn_cubecl_fusion::reduce::optimization::{
ReduceFallbackFn, ReduceInstruction, ReduceOptimization,
};
use burn_cubecl_fusion::CubeFusionHandle;
use burn_cubecl_fusion::{
elemwise::builder::ElementWiseBuilder, CubeOptimization, CubeOptimizationState,
};
use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
use burn_ir::{BackendIr, TensorHandle};
use burn_tensor::Shape;
use burn_tensor::{DType, Shape};
use core::marker::PhantomData;
use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum};
use cubecl::reduce::Reduce;
use half::{bf16, f16};
use std::sync::Arc;

@@ -76,10 +81,10 @@ impl<R: CubeRuntime> MatmulFallbackFn<R> for FallbackMatmul {
rhs: (CubeFusionHandle<R>, &[usize]),
) -> CubeFusionHandle<R> {
match lhs.0.dtype {
burn_tensor::DType::F64 => run_fallback_matmul::<R, f64>(lhs, rhs),
burn_tensor::DType::F32 => run_fallback_matmul::<R, f32>(lhs, rhs),
burn_tensor::DType::F16 => run_fallback_matmul::<R, f16>(lhs, rhs),
burn_tensor::DType::BF16 => run_fallback_matmul::<R, bf16>(lhs, rhs),
DType::F64 => run_fallback_matmul::<R, f64>(lhs, rhs),
DType::F32 => run_fallback_matmul::<R, f32>(lhs, rhs),
DType::F16 => run_fallback_matmul::<R, f16>(lhs, rhs),
DType::BF16 => run_fallback_matmul::<R, bf16>(lhs, rhs),
_ => todo!("Not yet supported"),
}
}
@@ -88,16 +93,19 @@ impl<R: CubeRuntime> ReduceFallbackFn<R> for FallbackReduce {
impl<R: CubeRuntime> ReduceFallbackFn<R> for FallbackReduce {
fn run(
&self,
input_handle: CubeFusionHandle<R>,
input: CubeFusionHandle<R>,
shape: &[usize],
axis: usize,
inst: &ReduceInstruction,
d_o: &DType,
) -> CubeFusionHandle<R> {
match input_handle.dtype {
burn_tensor::DType::F64 => run_fallback_reduce::<R, f64>(input_handle, shape, axis),
burn_tensor::DType::F32 => run_fallback_reduce::<R, f32>(input_handle, shape, axis),
burn_tensor::DType::F16 => run_fallback_reduce::<R, f16>(input_handle, shape, axis),
burn_tensor::DType::BF16 => run_fallback_reduce::<R, bf16>(input_handle, shape, axis),
_ => todo!("Not yet supported"),
let d_i = input.dtype;
match inst {
ReduceInstruction::ArgMax => reduce_dtype::<R, ArgMax>(input, shape, axis, &d_i, d_o),
ReduceInstruction::ArgMin => reduce_dtype::<R, ArgMin>(input, shape, axis, &d_i, d_o),
ReduceInstruction::Mean => reduce_dtype::<R, Mean>(input, shape, axis, &d_i, d_o),
ReduceInstruction::Prod => reduce_dtype::<R, Prod>(input, shape, axis, &d_i, d_o),
ReduceInstruction::Sum => reduce_dtype::<R, Sum>(input, shape, axis, &d_i, d_o),
}
}
}
@@ -135,7 +143,50 @@ fn run_fallback_matmul<R: CubeRuntime, EG: FloatElement>(
}
}

fn run_fallback_reduce<R: CubeRuntime, EG: FloatElement>(
fn reduce_dtype<R: CubeRuntime, Red: Reduce>(
input_handle: CubeFusionHandle<R>,
shape: &[usize],
axis: usize,
dtype_input: &DType,
dtype_output: &DType,
) -> CubeFusionHandle<R> {
match dtype_input {
DType::F64 => reduce_dtype_output::<R, f64, Red>(input_handle, shape, axis, dtype_output),
DType::F32 => reduce_dtype_output::<R, f32, Red>(input_handle, shape, axis, dtype_output),
DType::F16 => reduce_dtype_output::<R, f16, Red>(input_handle, shape, axis, dtype_output),
DType::BF16 => reduce_dtype_output::<R, bf16, Red>(input_handle, shape, axis, dtype_output),
DType::I64 => reduce_dtype_output::<R, i64, Red>(input_handle, shape, axis, dtype_output),
DType::I32 => reduce_dtype_output::<R, i32, Red>(input_handle, shape, axis, dtype_output),
DType::I16 => reduce_dtype_output::<R, i16, Red>(input_handle, shape, axis, dtype_output),
DType::U64 => reduce_dtype_output::<R, u64, Red>(input_handle, shape, axis, dtype_output),
DType::U32 => reduce_dtype_output::<R, u32, Red>(input_handle, shape, axis, dtype_output),
DType::U16 => reduce_dtype_output::<R, u16, Red>(input_handle, shape, axis, dtype_output),
_ => todo!("Not yet supported"),
}
}

fn reduce_dtype_output<R: CubeRuntime, In: CubeElement, Red: Reduce>(
input_handle: CubeFusionHandle<R>,
shape: &[usize],
axis: usize,
dtype_output: &DType,
) -> CubeFusionHandle<R> {
match dtype_output {
DType::F64 => reduce::<R, In, f64, Red>(input_handle, shape, axis),
DType::F32 => reduce::<R, In, f32, Red>(input_handle, shape, axis),
DType::F16 => reduce::<R, In, f16, Red>(input_handle, shape, axis),
DType::BF16 => reduce::<R, In, bf16, Red>(input_handle, shape, axis),
DType::I64 => reduce::<R, In, i64, Red>(input_handle, shape, axis),
DType::I32 => reduce::<R, In, i32, Red>(input_handle, shape, axis),
DType::I16 => reduce::<R, In, i16, Red>(input_handle, shape, axis),
DType::U64 => reduce::<R, In, u64, Red>(input_handle, shape, axis),
DType::U32 => reduce::<R, In, u32, Red>(input_handle, shape, axis),
DType::U16 => reduce::<R, In, u16, Red>(input_handle, shape, axis),
_ => todo!("Not yet supported"),
}
}

fn reduce<R: CubeRuntime, In: CubeElement, Out: CubeElement, Red: Reduce>(
input_handle: CubeFusionHandle<R>,
shape: &[usize],
axis: usize,
@@ -146,13 +197,12 @@ fn run_fallback_reduce<R: CubeRuntime, EG: FloatElement>(
dims: shape.to_vec(),
},
);
let out_tensor =
crate::kernel::reduce::reduce_dim::<R, EG, EG, cubecl::reduce::instructions::Sum>(
input_tensor,
axis,
crate::kernel::reduce::ReduceStrategy::default(),
)
.unwrap();
let out_tensor = crate::kernel::reduce::reduce_dim::<R, In, Out, Red>(
input_tensor,
axis,
crate::kernel::reduce::ReduceStrategy::default(),
)
.unwrap();

CubeFusionHandle {
client: out_tensor.client,
@@ -247,20 +297,17 @@ impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBack

type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;

fn cast_float(
tensor: burn_tensor::ops::FloatTensor<Self>,
dtype: burn_tensor::DType,
) -> Self::Handle {
fn cast_float(tensor: burn_tensor::ops::FloatTensor<Self>, dtype: DType) -> Self::Handle {
fn cast<R: CubeRuntime, F: FloatElement, FTarget: FloatElement>(
tensor: CubeTensor<R>,
) -> CubeFusionHandle<R> {
CubeFusionHandle::from(kernel::cast::<R, F, FTarget>(tensor))
}

match dtype {
burn_tensor::DType::F32 => cast::<R, F, f32>(tensor),
burn_tensor::DType::F16 => cast::<R, F, f16>(tensor),
burn_tensor::DType::BF16 => cast::<R, F, bf16>(tensor),
DType::F32 => cast::<R, F, f32>(tensor),
DType::F16 => cast::<R, F, f16>(tensor),
DType::BF16 => cast::<R, F, bf16>(tensor),
_ => panic!("Casting error: {dtype:?} unsupported."),
}
}
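The new reduce fallback reaches a fully monomorphized kernel call through two layers of runtime-dtype dispatch: `reduce_dtype` maps the input dtype to the `In` generic, `reduce_dtype_output` maps the requested output dtype to the `Out` generic, and `reduce` finally calls `reduce_dim::<R, In, Out, Red>`. The sketch below is a stripped-down, self-contained illustration of that double-dispatch pattern; the dtypes and element trait are simplified, and the `Red: Reduce` instruction generic from the actual diff is omitted for brevity.

```rust
use std::any::type_name;
use std::fmt::Debug;

#[derive(Debug, Clone, Copy)]
enum DType {
    F32,
    I32,
}

// Stand-in for the CubeElement bound used in the real code.
trait Element: Debug + Default {}
impl Element for f32 {}
impl Element for i32 {}

// Innermost call: both element types are concrete compile-time parameters here.
fn reduce<In: Element, Out: Element>(axis: usize) {
    println!(
        "reduce over axis {axis}: In = {}, Out = {}",
        type_name::<In>(),
        type_name::<Out>()
    );
}

// Second dispatch: runtime output dtype -> `Out` generic.
fn reduce_dtype_output<In: Element>(axis: usize, dtype_output: DType) {
    match dtype_output {
        DType::F32 => reduce::<In, f32>(axis),
        DType::I32 => reduce::<In, i32>(axis),
    }
}

// First dispatch: runtime input dtype -> `In` generic.
fn reduce_dtype(axis: usize, dtype_input: DType, dtype_output: DType) {
    match dtype_input {
        DType::F32 => reduce_dtype_output::<f32>(axis, dtype_output),
        DType::I32 => reduce_dtype_output::<i32>(axis, dtype_output),
    }
}

fn main() {
    // e.g. an ArgMax over an f32 tensor produces integer indices, so the
    // input and output dtypes can legitimately differ.
    reduce_dtype(0, DType::F32, DType::I32);
}
```

In the diff itself, the instruction is selected first: `FallbackReduce::run` matches on `ReduceInstruction` to pick the `Red` type (ArgMax, ArgMin, Mean, Prod, or Sum) before entering the dtype dispatch chain shown above.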
