Feat/improve fusion (#2773)
* WIP

* WIP

* WIP testing

* Very wip

* WIP works better

* Fix vectorization

* Still debug

* Fix some problems

* Fix other broadcast issues

* Fix another bug, but still very wip

* WIP Works

* Cleanup

* Support broadcasted vectorization

* Cleanup

* Still some bugs

* Fix multi vectorization broadcasting fused

* Add fuse settings

* Fix broadcast issue

* Fix performance

* Some cleanup

* Big refactoring

* Add reshape optimization

* Cleanup

* Add some docs

* Update cubecl ref

* Clippy + Fmt

* Add vulkan in example

* WIP

* Fix test

* Cleanup

* Fix no std tests

* Better autotune

* Remove print

* Update crates/burn-jit/src/fusion/on_write/trace/output.rs

* Update crates/burn-jit/src/fusion/on_write/trace/plan.rs

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
nathanielsimard and laggui authored Feb 6, 2025
1 parent 00422c1 commit 6015823
Showing 40 changed files with 1,996 additions and 835 deletions.
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 1 addition & 2 deletions backend-comparison/benches/matmul_fused.rs
@@ -26,8 +26,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
     }
 
     fn execute(&self, (lhs, rhs, bias): Self::Args) {
-        let bias = bias.unsqueeze();
-        gelu(relu(lhs.matmul(rhs)) + bias);
+        let _output = gelu(relu(lhs.matmul(rhs)) + bias.unsqueeze());
     }
 
     fn prepare(&self) -> Self::Args {
3 changes: 3 additions & 0 deletions crates/burn-core/src/lib.rs
@@ -59,12 +59,15 @@ extern crate alloc;
 pub type TestBackend = burn_ndarray::NdArray<f32>;
 
 #[cfg(all(test, feature = "test-tch"))]
+/// Backend for test cases
 pub type TestBackend = burn_tch::LibTorch<f32>;
 
 #[cfg(all(test, feature = "test-wgpu"))]
+/// Backend for test cases
 pub type TestBackend = burn_wgpu::Wgpu;
 
 #[cfg(all(test, feature = "test-cuda"))]
+/// Backend for test cases
 pub type TestBackend = burn_cuda::Cuda;
 
 /// Backend for autodiff test cases
1 change: 0 additions & 1 deletion crates/burn-core/src/nn/linear.rs
@@ -77,7 +77,6 @@ impl<B: Backend> Linear<B> {
 
         let weight = self.weight.val().unsqueeze();
         let bias = self.bias.as_ref().map(|b| b.val().unsqueeze());
-
         let output = input.matmul(weight);
 
         match bias {
25 changes: 11 additions & 14 deletions crates/burn-core/src/nn/transformer/decoder.rs
@@ -455,8 +455,9 @@ impl<B: Backend> TransformerDecoder<B> {
 
 #[cfg(test)]
 mod tests {
+    use burn_tensor::Device;
+
     use super::*;
-    use crate::tensor::Distribution;
     use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
 
     #[test]
@@ -481,20 +482,16 @@ mod tests {
     }
 
     fn test_autoregressive(config: TransformerDecoderConfig) {
-        let device = Default::default();
+        let device: Device<TestBackend> = Default::default();
         let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
-        let transformer = config.init(&device);
-
-        let memory = Tensor::<TestBackend, 3>::random(
-            [batch_size, seq_length, d_model],
-            Distribution::Default,
-            &device,
-        );
-        let target = Tensor::<TestBackend, 3>::random(
-            [batch_size, seq_length, d_model],
-            Distribution::Default,
-            &device,
-        );
+        let transformer = config.init::<TestBackend>(&device);
+
+        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
+            .float()
+            .reshape([batch_size, seq_length, d_model]);
+        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
+            .float()
+            .reshape([batch_size, seq_length, d_model]);
         let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
         let input = TransformerDecoderInput::new(target.clone(), memory.clone())
             .target_mask_attn(mask_attn);
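For readers unfamiliar with the pattern, here is a small generic sketch of the deterministic input construction the test switches to. It relies only on the `Tensor::arange(..).float().reshape(..)` chain visible in the diff; the helper name is mine, not part of the change.

```rust
use burn_tensor::{backend::Backend, Tensor};

/// Build the same deterministic [batch, seq, d_model] input the updated test
/// uses, instead of a random tensor, so values are reproducible across runs.
fn deterministic_input<B: Backend>(
    batch_size: usize,
    seq_length: usize,
    d_model: usize,
    device: &B::Device,
) -> Tensor<B, 3> {
    Tensor::arange(0..(batch_size * seq_length * d_model) as i64, device)
        .float()
        .reshape([batch_size, seq_length, d_model])
}
```

Both `memory` and `target` in the test are built this way with identical arguments.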
8 changes: 4 additions & 4 deletions crates/burn-fusion/src/ops/boolean.rs
@@ -17,8 +17,8 @@ use burn_tensor::{
         BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
         CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
         HandleContainer, OperationDescription, PermuteOperationDescription,
-        RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
-        SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
+        RepeatDimOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
+        SwapDimsDescription, UnaryOperationDescription,
     },
     Device, Shape,
 };
@@ -171,7 +171,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
     fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
         #[derive(new)]
         struct ReshapeDimsOps<B: FusionBackend> {
-            desc: ReshapeDescription,
+            desc: UnaryOperationDescription,
             _b: PhantomData<B>,
         }
 
@@ -186,7 +186,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         let stream = tensor.stream;
         let out = tensor.client.tensor_uninitialized(shape.dims, DType::Bool);
 
-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };
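The same `ReshapeDescription` to `UnaryOperationDescription` swap is repeated for the float and int ops below. As a rough illustration of why a dedicated reshape descriptor is redundant (the types here are simplified stand-ins, not the crate's actual ones): the target shape already lives in the output tensor description, so an input/output pair fully describes the operation.

```rust
/// Simplified stand-ins for the fusion descriptors, for illustration only.
#[derive(Debug, Clone)]
struct TensorDesc {
    id: u64,
    shape: Vec<usize>,
}

#[derive(Debug, Clone)]
struct UnaryDesc {
    input: TensorDesc,
    out: TensorDesc,
}

/// Reshape fits the generic unary form: the new shape is carried by `out`.
fn describe_reshape(input: TensorDesc, out_id: u64, new_shape: Vec<usize>) -> UnaryDesc {
    UnaryDesc {
        out: TensorDesc {
            id: out_id,
            shape: new_shape,
        },
        input,
    }
}

fn main() {
    let input = TensorDesc { id: 0, shape: vec![2, 3, 4] };
    let reshape = describe_reshape(input, 1, vec![6, 4]);
    assert_eq!(reshape.out.shape, vec![6, 4]);
    println!("{reshape:?}");
}
```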
4 changes: 2 additions & 2 deletions crates/burn-fusion/src/ops/float.rs
@@ -640,7 +640,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
     fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
         #[derive(new)]
         struct ReshapeDimsOps<B: FusionBackend> {
-            desc: ReshapeDescription,
+            desc: UnaryOperationDescription,
             _b: PhantomData<B>,
         }
 
@@ -656,7 +656,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         let dtype = tensor.dtype;
         let out = tensor.client.tensor_uninitialized(shape.dims, dtype);
 
-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };
4 changes: 2 additions & 2 deletions crates/burn-fusion/src/ops/int.rs
@@ -93,7 +93,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
     fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
         #[derive(new)]
         struct ReshapeDimsOps<B: FusionBackend> {
-            desc: ReshapeDescription,
+            desc: UnaryOperationDescription,
             _b: PhantomData<B>,
         }
 
@@ -110,7 +110,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
             .client
             .tensor_uninitialized(shape.dims, B::IntElem::dtype());
 
-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };
40 changes: 34 additions & 6 deletions crates/burn-fusion/src/stream/context.rs
@@ -39,12 +39,9 @@ pub struct Context<'a, H> {
     pub scalar_u8: &'a Vec<u8>,
 }
 
-#[derive(Default)]
 pub(crate) struct OperationConverter {
     tensors_relative2global: HashMap<TensorId, TensorDescription>,
     tensors_global2relative: HashMap<TensorId, TensorDescription>,
-    /// Only useful to create new shape ID.
-    /// You should use tensor descriptions to retrieve the proper shape.
     shapes_global2relative: HashMap<usize, usize>,
     scalar_f32: Vec<f32>,
     scalar_f16: Vec<f16>,
@@ -59,6 +56,32 @@ pub(crate) struct OperationConverter {
     scalar_u8: Vec<u8>,
 }
 
+impl Default for OperationConverter {
+    fn default() -> Self {
+        let mut val = Self {
+            tensors_relative2global: Default::default(),
+            tensors_global2relative: Default::default(),
+            shapes_global2relative: Default::default(),
+            scalar_f32: Default::default(),
+            scalar_f16: Default::default(),
+            scalar_bf16: Default::default(),
+            scalar_i64: Default::default(),
+            scalar_i32: Default::default(),
+            scalar_i16: Default::default(),
+            scalar_i8: Default::default(),
+            scalar_u64: Default::default(),
+            scalar_u32: Default::default(),
+            scalar_u16: Default::default(),
+            scalar_u8: Default::default(),
+        };
+
+        // global 1 is always shape id 0.
+        val.shapes_global2relative.insert(1, 0);
+
+        val
+    }
+}
+
 /// Fork of a [context](Context) which owns its data.
 pub struct ContextOwned<H> {
     tensors: HashMap<TensorId, TensorDescription>,
@@ -180,7 +203,11 @@ impl OperationConverter {
     pub(crate) fn clear(&mut self) {
         self.tensors_relative2global.clear();
         self.tensors_global2relative.clear();
+
         self.shapes_global2relative.clear();
+        // global 1 is always shape id 0.
+        self.shapes_global2relative.insert(1, 0);
+
         self.scalar_f32.clear();
         self.scalar_f16.clear();
         self.scalar_bf16.clear();
@@ -1129,7 +1156,7 @@ impl RelativeOps for BaseOperationDescription {
                 BaseOperationDescription::ToDevice(desc.to_relative(converter))
             }
             BaseOperationDescription::Reshape(desc) => {
-                BaseOperationDescription::Reshape(ReshapeDescription {
+                BaseOperationDescription::Reshape(UnaryOperationDescription {
                     input: desc.input.to_relative(converter),
                     out: desc.out.to_relative(converter),
                 })
@@ -1246,6 +1273,7 @@ impl RelativeOps for TensorDescription {
                     // We never saw this dim value before, therefore we create a new ID.
                     let dim_id = converter.shapes_global2relative.len();
                     relative_shape.push(dim_id);
+
                     converter.shapes_global2relative.insert(*dim, dim_id);
                 }
             }
@@ -1300,7 +1328,7 @@ mod tests {
             tensor1_local,
             TensorDescription {
                 id: TensorId::new(0),
-                shape: vec![0, 1, 2],
+                shape: vec![1, 2, 3],
                 status: TensorStatus::ReadOnly,
                 dtype: DType::F32
             }
@@ -1309,7 +1337,7 @@ mod tests {
             tensor2_local,
             TensorDescription {
                 id: TensorId::new(1),
-                shape: vec![0, 3, 2],
+                shape: vec![1, 4, 3],
                 status: TensorStatus::ReadOnly,
                 dtype: DType::F32
             }
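A worked sketch of the relative-shape bookkeeping these hunks adjust, re-implemented here under the assumption that it mirrors `shapes_global2relative` (the global shapes in `main` are made up). Reserving relative id 0 for dimension value 1 means every broadcast-sized dim maps to the same id and every freshly assigned id starts at 1, which is why the expected test shapes move from `[0, 1, 2]` / `[0, 3, 2]` to `[1, 2, 3]` / `[1, 4, 3]`.

```rust
use std::collections::HashMap;

/// Map a global shape to relative dimension ids: reuse the id of a dim value
/// seen before, otherwise hand out the next free id (assumed to mirror the
/// TensorDescription::to_relative logic above).
fn to_relative_shape(shape: &[usize], map: &mut HashMap<usize, usize>) -> Vec<usize> {
    shape
        .iter()
        .map(|dim| match map.get(dim) {
            Some(id) => *id,
            None => {
                let id = map.len();
                map.insert(*dim, id);
                id
            }
        })
        .collect()
}

fn main() {
    let mut map = HashMap::new();
    // New behaviour from Default/clear: global 1 is always shape id 0.
    map.insert(1, 0);

    // Hypothetical global shapes, chosen only to reproduce the ids the
    // updated unit test asserts.
    assert_eq!(to_relative_shape(&[512, 32, 2048], &mut map), vec![1, 2, 3]);
    assert_eq!(to_relative_shape(&[512, 64, 2048], &mut map), vec![1, 4, 3]);
}
```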
13 changes: 11 additions & 2 deletions crates/burn-jit/src/fusion/elemwise/builder.rs
@@ -2,7 +2,7 @@ use burn_fusion::OptimizationBuilder;
 
 use crate::{
     fusion::{
-        on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision},
+        on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings},
         JitOptimization,
     },
     JitRuntime,
@@ -23,7 +23,16 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
         let max_bindings = props.hardware_properties().max_bindings;
 
         Self {
-            builder: FuseOnWriteBuilder::new(max_bindings, bool_precision),
+            builder: FuseOnWriteBuilder::new(
+                max_bindings,
+                bool_precision,
+                FuseSettings {
+                    broadcast: true,
+                    output_shape_updates: true,
+                    mix_vectorization: true,
+                    inplace: true,
+                },
+            ),
             device,
         }
     }
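The diff only shows `FuseSettings` being constructed, so the following is a hypothetical sketch with local stand-in types (not the crate's API) of the kind of decision such flags can gate, e.g. declining to fuse shape-mismatched operands when broadcasting is disabled.

```rust
/// Stand-in for the settings constructed above; field names match the diff,
/// but the gating logic below is an assumption made for illustration.
#[derive(Clone, Copy, Debug)]
struct FuseSettings {
    broadcast: bool,
    output_shape_updates: bool,
    mix_vectorization: bool,
    inplace: bool,
}

/// Two element-wise operands fuse when their shapes match, or when
/// broadcasting is enabled and every axis is either equal or 1.
fn can_fuse(settings: FuseSettings, lhs: &[usize], rhs: &[usize]) -> bool {
    if lhs == rhs {
        return true;
    }
    settings.broadcast
        && lhs.len() == rhs.len()
        && lhs.iter().zip(rhs).all(|(&l, &r)| l == r || l == 1 || r == 1)
}

fn main() {
    let settings = FuseSettings {
        broadcast: true,
        output_shape_updates: true,
        mix_vectorization: true,
        inplace: true,
    };
    println!("{settings:?}");
    assert!(can_fuse(settings, &[8, 1, 32], &[8, 16, 32]));
    assert!(!can_fuse(
        FuseSettings { broadcast: false, ..settings },
        &[8, 1, 32],
        &[8, 16, 32]
    ));
}
```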
3 changes: 1 addition & 2 deletions crates/burn-jit/src/fusion/elemwise/optimization.rs
@@ -110,7 +110,6 @@ impl<R: JitRuntime> TraceRunner<R> for ElemwiseRunner {
             },
             None => panic!("Invalid argument"),
         };
-
        let total_elem = shape.iter().product::<usize>() / *vectorization as usize;
        let cube_dim = CubeDim::default();
        let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);
@@ -141,7 +140,7 @@ fn elemwise_fuse(
     let args = comptime![Sequence::<Arg>::new()];
     let pos = ABSOLUTE_POS;
 
-    let length = match comptime![config.ref_layout] {
+    let length = match comptime![config.ref_layout.clone()] {
         Arg::Input(index, precision, _) => match comptime![precision] {
             ElemwisePrecision::F32 => inputs.t_f32.index(index).len(),
             ElemwisePrecision::F16 => inputs.t_f16.index(index).len(),
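As a worked example of the launch-size line kept in the first hunk above (`total_elem = shape.iter().product() / vectorization`): with an illustrative output shape of `[32, 512]` and a vectorization factor of 4, the kernel covers 4096 work items. The helper and numbers below are mine, not the crate's.

```rust
/// Work items the element-wise kernel is launched over: total output elements
/// divided by the vectorization (line) width.
fn launch_elems(shape: &[usize], vectorization: u8) -> usize {
    shape.iter().product::<usize>() / vectorization as usize
}

fn main() {
    assert_eq!(launch_elems(&[32, 512], 4), 4096);
}
```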