From 909a782850e8e258da28441685fe68eedf642995 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 30 Apr 2024 09:43:25 -0400 Subject: [PATCH 01/25] Refactor handle container --- crates/burn-fusion/src/ops/binary.rs | 32 +-- crates/burn-fusion/src/ops/boolean.rs | 82 +++---- crates/burn-fusion/src/ops/float.rs | 211 +++++++++--------- crates/burn-fusion/src/ops/int.rs | 199 ++++++++--------- crates/burn-fusion/src/ops/module.rs | 176 +++++++-------- crates/burn-fusion/src/ops/unary.rs | 92 +++++--- crates/burn-fusion/src/server.rs | 24 +- crates/burn-fusion/src/stream/context.rs | 4 +- .../burn-fusion/src/stream/execution/base.rs | 10 +- crates/burn-fusion/src/stream/multi.rs | 10 +- crates/burn-tensor/src/repr/handle.rs | 59 +++-- 11 files changed, 465 insertions(+), 434 deletions(-) diff --git a/crates/burn-fusion/src/ops/binary.rs b/crates/burn-fusion/src/ops/binary.rs index e7035f674e..388c44cd4a 100644 --- a/crates/burn-fusion/src/ops/binary.rs +++ b/crates/burn-fusion/src/ops/binary.rs @@ -11,12 +11,12 @@ macro_rules! binary_float_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); - let rhs = handles.get_float_tensor(&self.desc.rhs); + fn execute(self: Box, handles: &mut HandleContainer<::Handle>) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); + let rhs = handles.get_float_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } }; @@ -35,12 +35,12 @@ macro_rules! binary_float_cmp_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); - let rhs = handles.get_float_tensor(&self.desc.rhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); + let rhs = handles.get_float_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } }; @@ -59,12 +59,12 @@ macro_rules! binary_int_cmp_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); - let rhs = handles.get_int_tensor(&self.desc.rhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); + let rhs = handles.get_int_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } }; @@ -93,12 +93,12 @@ macro_rules! 
binary_int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); - let rhs = handles.get_int_tensor(&self.desc.rhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); + let rhs = handles.get_int_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index c641209046..f1ea9bd49f 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -63,10 +63,10 @@ impl BoolTensorOps for Fusion { } impl Operation for IntoIntOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_int(input); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -96,10 +96,10 @@ impl BoolTensorOps for Fusion { } impl Operation for IntoFloatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_float(input); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -155,10 +155,10 @@ impl BoolTensorOps for Fusion { } impl Operation for ReshapeDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_reshape::(input, Shape::from(&self.desc.out.shape)); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -189,13 +189,13 @@ impl BoolTensorOps for Fusion { } impl Operation for SliceOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_bool_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_bool_tensor::(&self.desc.tensor); let output = B::bool_slice::(tensor, self.desc.ranges.clone().try_into().unwrap()); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -233,9 +233,9 @@ impl BoolTensorOps for Fusion { } impl Operation for SliceAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_bool_tensor::(&self.desc.tensor); - let value = handles.get_bool_tensor::(&self.desc.value); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_bool_tensor::(&self.desc.tensor); + let value = handles.get_bool_tensor::(&self.desc.value); let output = B::bool_slice_assign::( tensor, @@ -243,7 +243,7 @@ impl BoolTensorOps for Fusion { value, ); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -278,17 +278,17 @@ impl BoolTensorOps for Fusion { } impl Operation for CatOps { - fn execute(self: Box, handles: 
&mut HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc .tensors .iter() - .map(|tensor| handles.get_bool_tensor(tensor)) + .map(|tensor| handles.get_bool_tensor::(tensor)) .collect(); let output = B::bool_cat::(tensors, self.desc.dim); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -330,11 +330,11 @@ impl BoolTensorOps for Fusion { } impl Operation for EqualOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_bool_tensor::(&self.desc.lhs); - let rhs = handles.get_bool_tensor(&self.desc.rhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_bool_tensor::(&self.desc.lhs); + let rhs = handles.get_bool_tensor::(&self.desc.rhs); let output = B::bool_equal(lhs, rhs); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -365,10 +365,10 @@ impl BoolTensorOps for Fusion { } impl Operation for NotOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_not(input); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -400,10 +400,10 @@ impl BoolTensorOps for Fusion { } impl Operation for SwapDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -439,11 +439,11 @@ impl BoolTensorOps for Fusion { } impl Operation for PermuteDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); let output = B::bool_permute(input, axes); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -479,12 +479,12 @@ impl BoolTensorOps for Fusion { } impl Operation for ExpandOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); let output = B::bool_expand(input, shape.into()); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -517,10 +517,10 @@ impl BoolTensorOps for Fusion { } impl Operation for FlipOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_flip(input, self.desc.axes.as_slice()); - handles.register_bool_tensor(&self.desc.out.id, output); + 
handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -553,12 +553,12 @@ impl BoolTensorOps for Fusion { } impl Operation for RepeatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_bool_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_bool_tensor::(&self.desc.tensor); let output = B::bool_repeat::(tensor, self.desc.dim, self.desc.times); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index e282f11abd..e220cf79e3 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -7,20 +7,9 @@ use crate::{ stream::{execution::Operation, StreamId}, unary_float_ops, Fusion, FusionBackend, }; - use burn_tensor::{ ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, - repr::{ - BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, - ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - FloatOperationDescription, GatherOperationDescription, HandleContainer, - MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, - OperationDescription, PermuteOperationDescription, RandomOperationDescription, - ReduceDimWithIndicesDescription, RepeatOperationDescription, ReshapeDescription, - ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, - SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, - SwapDimsDescription, TensorDescription, UnaryOperationDescription, - }, + repr::*, Data, Device, Distribution, ElementConversion, Reader, Shape, }; use std::ops::Range; @@ -47,16 +36,17 @@ impl FloatTensorOps for Fusion { device: &Device, ) -> FloatTensor { #[derive(new)] - struct RandomOps { + struct RandomOps { desc: RandomOperationDescription, + device: Device, } - impl Operation for RandomOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for RandomOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.out.shape.clone()); let output: B::FloatTensorPrimitive = - B::float_random(shape, self.desc.distribution, &handles.device); - handles.register_float_tensor(&self.desc.out.id, output); + B::float_random(shape, self.desc.distribution, &self.device); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -72,7 +62,7 @@ impl FloatTensorOps for Fusion { client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Random(desc.clone())), - RandomOps::::new(desc), + RandomOps::::new(desc, device.clone()), ); out @@ -80,15 +70,16 @@ impl FloatTensorOps for Fusion { fn float_zeros(shape: Shape, device: &Device) -> FloatTensor { #[derive(new)] - struct ZerosOps { + struct ZerosOps { out: TensorDescription, + device: Device, } - impl Operation for ZerosOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for ZerosOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); - let output = B::float_zeros::(shape, &handles.device); - handles.register_float_tensor(&self.out.id, output); + let output = B::float_zeros::(shape, &self.device); + handles.register_float_tensor::(&self.out.id, output); } } @@ -101,7 +92,7 @@ impl FloatTensorOps for Fusion { client.register( 
vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Zeros(desc.clone())), - ZerosOps::::new(desc), + ZerosOps::::new(desc, device.clone()), ); out @@ -109,15 +100,16 @@ impl FloatTensorOps for Fusion { fn float_ones(shape: Shape, device: &Device) -> FloatTensor { #[derive(new)] - struct OnesOps { + struct OnesOps { out: TensorDescription, + device: Device, } - impl Operation for OnesOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for OnesOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); - let output = B::float_ones::(shape, &handles.device); - handles.register_float_tensor(&self.out.id, output); + let output = B::float_ones::(shape, &self.device); + handles.register_float_tensor::(&self.out.id, output); } } @@ -130,7 +122,7 @@ impl FloatTensorOps for Fusion { client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Ones(desc.clone())), - OnesOps::::new(desc), + OnesOps::::new(desc, device.clone()), ); out @@ -142,17 +134,18 @@ impl FloatTensorOps for Fusion { device: &Device, ) -> FloatTensor { #[derive(new)] - struct FullOps { + struct FullOps { out: TensorDescription, elem: f32, + device: Device, } - impl Operation for FullOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for FullOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); let output: B::FloatTensorPrimitive = - B::float_full(shape, self.elem.elem(), &handles.device); - handles.register_float_tensor(&self.out.id, output); + B::float_full(shape, self.elem.elem(), &self.device); + handles.register_float_tensor::(&self.out.id, output); } } @@ -165,7 +158,7 @@ impl FloatTensorOps for Fusion { client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Full(desc.clone())), - FullOps::::new(desc.0, desc.1), + FullOps::::new(desc.0, desc.1, device.clone()), ); out @@ -214,11 +207,11 @@ impl FloatTensorOps for Fusion { } impl Operation for IntoIntOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_into_int(input); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -309,11 +302,11 @@ impl FloatTensorOps for Fusion { } impl Operation for ClampOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_clamp(input, self.desc.min.elem(), self.desc.max.elem()); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -553,10 +546,10 @@ impl FloatTensorOps for Fusion { } impl Operation for SwapDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, 
output); } } @@ -593,10 +586,10 @@ impl FloatTensorOps for Fusion { } impl Operation for ReshapeDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_reshape::(input, Shape::from(&self.desc.out.shape)); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -628,12 +621,12 @@ impl FloatTensorOps for Fusion { } impl Operation for GatherOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::float_gather(self.desc.dim, tensor, indices); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -669,14 +662,14 @@ impl FloatTensorOps for Fusion { } impl Operation for ScatterOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); - let value = handles.get_float_tensor(&self.desc.value); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + let value = handles.get_float_tensor::(&self.desc.value); let output = B::float_scatter(self.desc.dim, tensor, indices, value); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -714,13 +707,13 @@ impl FloatTensorOps for Fusion { } impl Operation for SelectOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::float_select(tensor, self.desc.dim, indices); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -756,14 +749,14 @@ impl FloatTensorOps for Fusion { } impl Operation for SelectAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); - let value = handles.get_float_tensor(&self.desc.value); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + let value = handles.get_float_tensor::(&self.desc.value); let output = B::float_select_assign(tensor, self.desc.dim, indices, value); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -801,13 +794,13 @@ impl FloatTensorOps for Fusion { } impl Operation for SliceOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = 
handles.get_float_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_slice::(tensor, self.desc.ranges.clone().try_into().unwrap()); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; @@ -844,9 +837,9 @@ impl FloatTensorOps for Fusion { } impl Operation for SliceAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let value = handles.get_float_tensor::(&self.desc.value); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let value = handles.get_float_tensor::(&self.desc.value); let output = B::float_slice_assign::( tensor, @@ -854,7 +847,7 @@ impl FloatTensorOps for Fusion { value, ); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -889,14 +882,14 @@ impl FloatTensorOps for Fusion { } impl Operation for MaskWhereOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let value = handles.get_float_tensor(&self.desc.value); - let mask = handles.get_bool_tensor(&self.desc.mask); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let value = handles.get_float_tensor::(&self.desc.value); + let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::float_mask_where(tensor, mask, value); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -934,13 +927,13 @@ impl FloatTensorOps for Fusion { } impl Operation for MaskFillOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let mask = handles.get_bool_tensor(&self.desc.mask); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::float_mask_fill(tensor, mask, self.desc.value.elem()); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -1223,7 +1216,7 @@ impl FloatTensorOps for Fusion { } fn float_sum(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(SumOps, B::float_sum); + unary_float_ops!(SumOps, B::float_sum, reduce); let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(vec![1]); @@ -1267,7 +1260,7 @@ impl FloatTensorOps for Fusion { } fn float_mean(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MeanOps, B::float_mean); + unary_float_ops!(MeanOps, B::float_mean, reduce); let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(vec![1]); @@ -1532,17 +1525,17 @@ impl FloatTensorOps for Fusion { } impl Operation for CatOps { - fn execute(self: Box, handles: &mut HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc .tensors .iter() - .map(|tensor| handles.get_float_tensor(tensor)) + .map(|tensor| handles.get_float_tensor::(tensor)) .collect(); let output = B::float_cat::(tensors, self.desc.dim); - handles.register_float_tensor(&self.desc.out.id, output); + 
handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -1609,12 +1602,12 @@ impl FloatTensorOps for Fusion { } impl Operation for RepeatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_repeat::(tensor, self.desc.dim, self.desc.times); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -1664,7 +1657,7 @@ impl FloatTensorOps for Fusion { } fn float_max(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MaxOps, B::float_max); + unary_float_ops!(MaxOps, B::float_max, reduce); let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(vec![1]); @@ -1717,12 +1710,12 @@ impl FloatTensorOps for Fusion { } impl Operation for MaxDimWithIndicesOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim); - handles.register_float_tensor(&self.desc.out.id, output); - handles.register_int_tensor(&self.desc.out_indices.id, indices); + handles.register_float_tensor::(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } @@ -1751,7 +1744,7 @@ impl FloatTensorOps for Fusion { } fn float_min(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MinOps, B::float_min); + unary_float_ops!(MinOps, B::float_min, reduce); let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(vec![1]); @@ -1804,12 +1797,12 @@ impl FloatTensorOps for Fusion { } impl Operation for MinDimWithIndicesOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim); - handles.register_float_tensor(&self.desc.out.id, output); - handles.register_int_tensor(&self.desc.out_indices.id, indices); + handles.register_float_tensor::(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } @@ -1873,11 +1866,11 @@ impl FloatTensorOps for Fusion { } impl Operation for PermuteDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); let output = B::float_permute(input, axes); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -1913,12 +1906,12 @@ impl FloatTensorOps for Fusion { } impl Operation for ExpandOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); let output = B::float_expand(input, shape.into()); - 
handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -1951,10 +1944,10 @@ impl FloatTensorOps for Fusion { } impl Operation for FlipOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_flip(input, &self.desc.axes); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 2e691707bf..4ebd0b7fe0 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -9,17 +9,7 @@ use crate::{ }; use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, - repr::{ - self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, - ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - GatherOperationDescription, HandleContainer, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatOperationDescription, ReshapeDescription, ScalarOperationDescription, - ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, - TensorDescription, UnaryOperationDescription, - }, + repr::{self, *}, Data, Device, Distribution, ElementConversion, Reader, Shape, }; use core::ops::Range; @@ -87,10 +77,10 @@ impl IntTensorOps for Fusion { } impl Operation for ReshapeDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_reshape::(input, Shape::from(&self.desc.out.shape)); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -121,13 +111,13 @@ impl IntTensorOps for Fusion { } impl Operation for SliceOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_slice::(tensor, self.desc.ranges.clone().try_into().unwrap()); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -165,9 +155,9 @@ impl IntTensorOps for Fusion { } impl Operation for SliceAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let value = handles.get_int_tensor::(&self.desc.value); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let value = handles.get_int_tensor::(&self.desc.value); let output = B::int_slice_assign::( tensor, @@ -175,7 +165,7 @@ impl IntTensorOps for Fusion { value, ); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -209,14 +199,14 @@ impl IntTensorOps for Fusion { } impl 
Operation for MaskWhereOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let value = handles.get_int_tensor(&self.desc.value); - let mask = handles.get_bool_tensor(&self.desc.mask); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let value = handles.get_int_tensor::(&self.desc.value); + let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::int_mask_where(tensor, mask, value); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -252,13 +242,13 @@ impl IntTensorOps for Fusion { } impl Operation for MaskFillOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let mask = handles.get_bool_tensor(&self.desc.mask); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::int_mask_fill(tensor, mask, self.desc.value.elem()); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -292,12 +282,12 @@ impl IntTensorOps for Fusion { } impl Operation for GatherOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::int_gather(self.desc.dim, tensor, indices); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -332,14 +322,14 @@ impl IntTensorOps for Fusion { } impl Operation for ScatterOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); - let value = handles.get_int_tensor(&self.desc.value); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + let value = handles.get_int_tensor::(&self.desc.value); let output = B::int_scatter(self.desc.dim, tensor, indices, value); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -375,13 +365,13 @@ impl IntTensorOps for Fusion { } impl Operation for SelectOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::int_select(tensor, self.desc.dim, indices); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -417,14 +407,14 @@ impl IntTensorOps for Fusion { } impl Operation for SelectAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); - 
let value = handles.get_int_tensor(&self.desc.value); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + let value = handles.get_int_tensor::(&self.desc.value); let output = B::int_select_assign(tensor, self.desc.dim, indices, value); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -458,17 +448,17 @@ impl IntTensorOps for Fusion { } impl Operation for CatOps { - fn execute(self: Box, handles: &mut HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc .tensors .iter() - .map(|tensor| handles.get_int_tensor(tensor)) + .map(|tensor| handles.get_int_tensor::(tensor)) .collect(); let output = B::int_cat::(tensors, self.desc.dim); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -983,15 +973,16 @@ impl IntTensorOps for Fusion { fn int_zeros(shape: Shape, device: &Device) -> IntTensor { #[derive(new)] - struct ZerosOps { + struct ZerosOps { desc: TensorDescription, + device: Device, } - impl Operation for ZerosOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for ZerosOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.shape.clone()); - let output = B::int_zeros::(shape, &handles.device); - handles.register_int_tensor(&self.desc.id, output); + let output = B::int_zeros::(shape, &self.device); + handles.register_int_tensor::(&self.desc.id, output); } } @@ -1003,7 +994,7 @@ impl IntTensorOps for Fusion { client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Zeros(desc.clone())), - ZerosOps::::new(desc), + ZerosOps::::new(desc, device.clone()), ); out @@ -1011,15 +1002,16 @@ impl IntTensorOps for Fusion { fn int_ones(shape: Shape, device: &Device) -> IntTensor { #[derive(new)] - struct OnesOps { + struct OnesOps { desc: TensorDescription, + device: Device, } - impl Operation for OnesOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for OnesOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.shape.clone()); - let output = B::int_ones::(shape, &handles.device); - handles.register_int_tensor(&self.desc.id, output); + let output = B::int_ones::(shape, &self.device); + handles.register_int_tensor::(&self.desc.id, output); } } @@ -1032,14 +1024,14 @@ impl IntTensorOps for Fusion { client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Ones(desc.clone())), - OnesOps::::new(desc), + OnesOps::::new(desc, device.clone()), ); out } fn int_sum(tensor: IntTensor) -> IntTensor { - unary_int_ops!(SumOps, B::int_sum); + unary_int_ops!(SumOps, B::int_sum, reduce); let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(vec![1]); @@ -1080,7 +1072,7 @@ impl IntTensorOps for Fusion { } fn int_prod(tensor: IntTensor) -> IntTensor { - unary_int_ops!(ProdOps, B::int_prod); + unary_int_ops!(ProdOps, B::int_prod, reduce); let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(vec![1]); @@ -1121,7 +1113,7 @@ impl IntTensorOps for Fusion { } fn int_mean(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MeanOps, B::int_mean); + unary_int_ops!(MeanOps, B::int_mean, reduce); let stream = tensor.stream; let out = 
tensor.client.tensor_uninitialized(vec![1]); @@ -1216,11 +1208,11 @@ impl IntTensorOps for Fusion { } impl Operation for ClampOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem()); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -1267,10 +1259,10 @@ impl IntTensorOps for Fusion { } impl Operation for IntoFloatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_into_float(input); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -1300,10 +1292,10 @@ impl IntTensorOps for Fusion { } impl Operation for SwapDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -1330,7 +1322,7 @@ impl IntTensorOps for Fusion { } fn int_max(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MaxOps, B::int_max); + unary_int_ops!(MaxOps, B::int_max, reduce); let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(vec![1]); @@ -1380,12 +1372,12 @@ impl IntTensorOps for Fusion { } impl Operation for MaxDimWithIndicesOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim); - handles.register_int_tensor(&self.desc.out.id, output); - handles.register_int_tensor(&self.desc.out_indices.id, indices); + handles.register_int_tensor::(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } @@ -1413,7 +1405,7 @@ impl IntTensorOps for Fusion { } fn int_min(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MinOps, B::int_min); + unary_int_ops!(MinOps, B::int_min, reduce); let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(vec![1]); @@ -1463,12 +1455,12 @@ impl IntTensorOps for Fusion { } impl Operation for MinDimWithIndicesOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim); - handles.register_int_tensor(&self.desc.out.id, output); - handles.register_int_tensor(&self.desc.out_indices.id, indices); + handles.register_int_tensor::(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } @@ -1501,16 +1493,17 @@ impl IntTensorOps for Fusion { device: &Device, ) -> IntTensor { #[derive(new)] - struct 
IntRandomOps { + struct IntRandomOps { desc: RandomOperationDescription, + device: Device, } - impl Operation for IntRandomOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for IntRandomOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.out.shape.clone()); let output: B::IntTensorPrimitive = - B::int_random(shape, self.desc.distribution, &handles.device); - handles.register_int_tensor(&self.desc.out.id, output); + B::int_random(shape, self.desc.distribution, &self.device); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -1526,7 +1519,7 @@ impl IntTensorOps for Fusion { client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::IntRandom(desc.clone())), - IntRandomOps::::new(desc), + IntRandomOps::::new(desc, device.clone()), ); out @@ -1542,11 +1535,11 @@ impl IntTensorOps for Fusion { } impl Operation for PermuteDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); let output = B::int_permute(input, axes); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -1582,11 +1575,11 @@ impl IntTensorOps for Fusion { } impl Operation for ExpandOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); let output = B::bool_expand(input, shape.into()); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -1616,11 +1609,11 @@ impl IntTensorOps for Fusion { } impl Operation for FlipDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let axes = &self.desc.axes; let output = B::int_flip(input, axes); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -1654,12 +1647,12 @@ impl IntTensorOps for Fusion { } impl Operation for RepeatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_repeat::(tensor, self.desc.dim, self.desc.times); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } diff --git a/crates/burn-fusion/src/ops/module.rs b/crates/burn-fusion/src/ops/module.rs index a2543559c8..a3150c97a7 100644 --- a/crates/burn-fusion/src/ops/module.rs +++ b/crates/burn-fusion/src/ops/module.rs @@ -9,17 +9,7 @@ use burn_tensor::{ MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }, - repr::{ - AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription, - AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription, - AvgPool1dBackwardDescription, AvgPool1dDescription, 
AvgPool2dBackwardDescription, - AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, - ConvTranspose2dDescription, HandleContainer, InterpolateBackwardDescription, - InterpolateDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, - MaxPool1dWithIndicesDescription, MaxPool2dDescription, - MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, - ModuleOperationDescription, OperationDescription, - }, + repr::*, }; macro_rules! make_ops { @@ -30,7 +20,7 @@ macro_rules! make_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { #[allow(clippy::redundant_closure_call)] $fn(self.desc, handles) } @@ -48,15 +38,15 @@ impl ModuleOps> for Fusion { make_ops!( Conv1dOps, Conv1dDescription, - |desc: Conv1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&desc.x); - let weight = handles.get_float_tensor(&desc.weight); + |desc: Conv1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&desc.x); + let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() - .map(|bias| handles.get_float_tensor(bias)); + .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv1d(x, weight, bias, desc.options.into()); - handles.register_float_tensor(&desc.out.id, output); + handles.register_float_tensor::(&desc.out.id, output); } ); @@ -104,17 +94,17 @@ impl ModuleOps> for Fusion { make_ops!( Conv2dOps, Conv2dDescription, - |args: Conv2dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); + |args: Conv2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() - .map(|bias| handles.get_float_tensor(bias)); + .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv2d(x, weight, bias, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -169,17 +159,17 @@ impl ModuleOps> for Fusion { make_ops!( ConvTranspose1dOps, ConvTranspose1dDescription, - |args: ConvTranspose1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); + |args: ConvTranspose1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() - .map(|bias| handles.get_float_tensor(bias)); + .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -228,17 +218,17 @@ impl ModuleOps> for Fusion { make_ops!( ConvTranspose2dOps, ConvTranspose2dDescription, - |args: ConvTranspose2dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); + |args: ConvTranspose2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() - .map(|bias| handles.get_float_tensor(bias)); + .map(|bias| 
handles.get_float_tensor::(bias)); let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -296,8 +286,8 @@ impl ModuleOps> for Fusion { make_ops!( AvgPool1dOps, AvgPool1dDescription, - |args: AvgPool1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: AvgPool1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::avg_pool1d( x, args.kernel_size, @@ -306,7 +296,7 @@ impl ModuleOps> for Fusion { args.count_include_pad, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -342,8 +332,8 @@ impl ModuleOps> for Fusion { make_ops!( AvgPool2dOps, AvgPool2dDescription, - |args: AvgPool2dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: AvgPool2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::avg_pool2d( x, args.kernel_size, @@ -352,7 +342,7 @@ impl ModuleOps> for Fusion { args.count_include_pad, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -393,9 +383,9 @@ impl ModuleOps> for Fusion { make_ops!( AvgPool1dBackwardOps, AvgPool1dBackwardDescription, - |args: AvgPool1dBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: AvgPool1dBackwardDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::avg_pool1d_backward( x, grad, @@ -405,7 +395,7 @@ impl ModuleOps> for Fusion { args.count_include_pad, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -444,9 +434,9 @@ impl ModuleOps> for Fusion { make_ops!( AvgPool2dBackwardOps, AvgPool2dBackwardDescription, - |args: AvgPool2dBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: AvgPool2dBackwardDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::avg_pool2d_backward( x, grad, @@ -456,7 +446,7 @@ impl ModuleOps> for Fusion { args.count_include_pad, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -494,8 +484,8 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool1dOps, MaxPool1dDescription, - |args: MaxPool1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: MaxPool1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::max_pool1d( x, args.kernel_size, @@ -504,7 +494,7 @@ impl ModuleOps> for Fusion { args.dilation, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -541,8 +531,8 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool2dOps, MaxPool2dDescription, - |args: MaxPool2dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: MaxPool2dDescription, handles: &mut HandleContainer| { + let 
x = handles.get_float_tensor::(&args.x); let output = B::max_pool2d( x, args.kernel_size, @@ -551,7 +541,7 @@ impl ModuleOps> for Fusion { args.dilation, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -601,8 +591,8 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool1dWithIndicesOps, MaxPool1dWithIndicesDescription, - |args: MaxPool1dWithIndicesDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: MaxPool1dWithIndicesDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::max_pool1d_with_indices( x, args.kernel_size, @@ -611,8 +601,8 @@ impl ModuleOps> for Fusion { args.dilation, ); - handles.register_float_tensor(&args.out.id, output.output); - handles.register_int_tensor(&args.out_indices.id, output.indices); + handles.register_float_tensor::(&args.out.id, output.output); + handles.register_int_tensor::(&args.out_indices.id, output.indices); } ); @@ -652,8 +642,8 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool2dWithIndicesOps, MaxPool2dWithIndicesDescription, - |args: MaxPool2dWithIndicesDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: MaxPool2dWithIndicesDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::max_pool2d_with_indices( x, args.kernel_size, @@ -662,8 +652,8 @@ impl ModuleOps> for Fusion { args.dilation, ); - handles.register_float_tensor(&args.out.id, output.output); - handles.register_int_tensor(&args.out_indices.id, output.indices); + handles.register_float_tensor::(&args.out.id, output.output); + handles.register_int_tensor::(&args.out_indices.id, output.indices); } ); @@ -719,10 +709,11 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool1dWithIndicesBackwardOps, MaxPool1dWithIndicesBackwardDescription, - |args: MaxPool1dWithIndicesBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let indices = handles.get_int_tensor(&args.indices); + |args: MaxPool1dWithIndicesBackwardDescription, + handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); + let indices = handles.get_int_tensor::(&args.indices); let output = B::max_pool1d_with_indices_backward( x, args.kernel_size, @@ -733,7 +724,7 @@ impl ModuleOps> for Fusion { indices, ); - handles.register_float_tensor(&args.out.id, output.x_grad); + handles.register_float_tensor::(&args.out.id, output.x_grad); } ); @@ -775,10 +766,11 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool2dWithIndicesBackwardOps, MaxPool2dWithIndicesBackwardDescription, - |args: MaxPool2dWithIndicesBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let indices = handles.get_int_tensor(&args.indices); + |args: MaxPool2dWithIndicesBackwardDescription, + handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); + let indices = handles.get_int_tensor::(&args.indices); let output = B::max_pool2d_with_indices_backward( x, args.kernel_size, @@ -789,7 +781,7 @@ impl ModuleOps> for Fusion { indices, ); - handles.register_float_tensor(&args.out.id, output.x_grad); + handles.register_float_tensor::(&args.out.id, output.x_grad); } ); @@ -823,11 
+815,11 @@ impl ModuleOps> for Fusion { make_ops!( AdaptiveAvgPool1dOps, AdaptiveAvgPool1dDescription, - |args: AdaptiveAvgPool1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: AdaptiveAvgPool1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::adaptive_avg_pool1d(x, args.output_size); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -858,11 +850,11 @@ impl ModuleOps> for Fusion { make_ops!( AdaptiveAvgPool2dOps, AdaptiveAvgPool2dDescription, - |args: AdaptiveAvgPool2dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: AdaptiveAvgPool2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::adaptive_avg_pool2d(x, args.output_size); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -893,12 +885,13 @@ impl ModuleOps> for Fusion { make_ops!( AdaptiveAvgPool1dBackwardOps, AdaptiveAvgPool1dBackwardDescription, - |args: AdaptiveAvgPool1dBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: AdaptiveAvgPool1dBackwardDescription, + handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::adaptive_avg_pool1d_backward(x, grad); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -929,12 +922,13 @@ impl ModuleOps> for Fusion { make_ops!( AdaptiveAvgPool2dBackwardOps, AdaptiveAvgPool2dBackwardDescription, - |args: AdaptiveAvgPool2dBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: AdaptiveAvgPool2dBackwardDescription, + handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::adaptive_avg_pool2d_backward(x, grad); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -966,10 +960,10 @@ impl ModuleOps> for Fusion { make_ops!( InterpolateOps, InterpolateDescription, - |args: InterpolateDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: InterpolateDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::interpolate(x, args.output_size, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -1002,13 +996,13 @@ impl ModuleOps> for Fusion { make_ops!( InterpolateBackwardOps, InterpolateBackwardDescription, - |args: InterpolateBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: InterpolateBackwardDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::interpolate_backward(x, grad, args.output_size, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); 
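[Sketch of the pattern, for review; not part of the diff.] Across the files above, the refactor has one shape: HandleContainer becomes generic over the backend's handle type instead of owning a Device, every get_*/register_* call names its backend explicitly, and creation ops (ZerosOps, OnesOps, FullOps, RandomOps, IntRandomOps) capture the device at construction time rather than reading handles.device. Below is a minimal, self-contained sketch of that pattern; TensorId as u64, Vec<f32> as the handle, and String as the device are illustrative placeholders, not the real burn-tensor/burn-fusion types.

use std::collections::HashMap;

// Placeholder stand-ins; none of these are the real burn types.
type TensorId = u64;

// After this patch, the container is generic over the handle type and no
// longer owns a device.
struct HandleContainer<H> {
    handles: HashMap<TensorId, H>,
}

impl<H: Clone> HandleContainer<H> {
    fn new() -> Self {
        Self { handles: HashMap::new() }
    }

    fn register(&mut self, id: TensorId, handle: H) {
        self.handles.insert(id, handle);
    }

    fn get(&self, id: &TensorId) -> H {
        self.handles[id].clone()
    }
}

// Deferred ops execute against the container, as in the macros above.
trait Operation<H> {
    fn execute(self: Box<Self>, handles: &mut HandleContainer<H>);
}

// Creation ops now carry the device they were built on instead of
// reading it from the container.
struct ZerosOps {
    out: TensorId,
    len: usize,
    device: String, // stand-in for Device<B>
}

impl Operation<Vec<f32>> for ZerosOps {
    fn execute(self: Box<Self>, handles: &mut HandleContainer<Vec<f32>>) {
        println!("allocating zeros on {}", self.device);
        handles.register(self.out, vec![0.0; self.len]);
    }
}

fn main() {
    let mut handles = HandleContainer::new();
    Box::new(ZerosOps { out: 0, len: 4, device: "cuda:0".into() }).execute(&mut handles);
    assert_eq!(handles.get(&0), vec![0.0; 4]);
}

Moving the device onto the ops is presumably what lets the container definition live in burn-tensor's repr module and stay backend-agnostic; as the server.rs hunk below shows, FusionServer now builds it with a plain HandleContainer::new() and keeps the device itself.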
diff --git a/crates/burn-fusion/src/ops/unary.rs b/crates/burn-fusion/src/ops/unary.rs index d6f0833c64..921f2503da 100644 --- a/crates/burn-fusion/src/ops/unary.rs +++ b/crates/burn-fusion/src/ops/unary.rs @@ -18,11 +18,11 @@ macro_rules! scalar_float_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } }; @@ -38,11 +38,11 @@ macro_rules! scalar_float_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } }; @@ -62,11 +62,11 @@ macro_rules! scalar_float2int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs.clone()); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; @@ -85,11 +85,30 @@ macro_rules! unary_float_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = $ops(input); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); + } + } + }; + ( + $name:ident, + $ops:expr, + reduce + ) => { + #[derive(new)] + struct $name { + desc: UnaryOperationDescription, + } + + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); + let output = $ops(input); + + handles.register_float_tensor::(&self.desc.out.id, output); } } }; @@ -108,11 +127,30 @@ macro_rules! unary_int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); + let output = $ops(input); + + handles.register_int_tensor::(&self.desc.out.id, output); + } + } + }; + ( + $name:ident, + $ops:expr, + reduce + ) => { + #[derive(new)] + struct $name { + desc: UnaryOperationDescription, + } + + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let output = $ops(input); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; @@ -131,11 +169,11 @@ macro_rules! 
scalar_float_cmp_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } }; @@ -154,11 +192,11 @@ macro_rules! scalar_int_cmp_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } }; @@ -184,11 +222,11 @@ macro_rules! scalar_int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; @@ -204,11 +242,11 @@ macro_rules! scalar_int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index 4c6ec62fc7..3e6b5d911f 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -13,7 +13,7 @@ where B: FusionBackend, { streams: MultiStream, - pub(crate) handles: HandleContainer, + pub(crate) handles: HandleContainer, pub device: B::Device, } @@ -24,7 +24,7 @@ where pub fn new(device: B::Device) -> Self { Self { streams: MultiStream::new(device.clone()), - handles: HandleContainer::new(device.clone()), + handles: HandleContainer::new(), device, } } @@ -56,7 +56,7 @@ where // The underlying backend can still be async. self.drain_stream(id); - let tensor = self.handles.get_float_tensor(&tensor); + let tensor = self.handles.get_float_tensor::(&tensor); B::float_into_data(tensor) } @@ -69,7 +69,7 @@ where // The underlying backend can still be async. self.drain_stream(id); - let tensor = self.handles.get_int_tensor(&tensor); + let tensor = self.handles.get_int_tensor::(&tensor); B::int_into_data(tensor) } @@ -82,7 +82,7 @@ where // The underlying backend can still be async. 
self.drain_stream(id); - let tensor = self.handles.get_bool_tensor(&tensor); + let tensor = self.handles.get_bool_tensor::(&tensor); B::bool_into_data(tensor) } @@ -92,45 +92,47 @@ where device: &B::Device, server_device: &mut Self, ) -> Arc { - let tensor = self.handles.get_float_tensor::(tensor); + let tensor = self.handles.get_float_tensor::(tensor); let tensor = B::float_to_device(tensor, device); let id = server_device.create_empty_handle(); server_device .handles - .register_float_tensor(&id, tensor.clone()); + .register_float_tensor::(&id, tensor.clone()); id } + pub fn change_server_int( &mut self, tensor: &TensorDescription, device: &B::Device, server_device: &mut Self, ) -> Arc { - let tensor = self.handles.get_int_tensor::(tensor); + let tensor = self.handles.get_int_tensor::(tensor); let tensor = B::int_to_device(tensor, device); let id = server_device.create_empty_handle(); server_device .handles - .register_int_tensor(&id, tensor.clone()); + .register_int_tensor::(&id, tensor.clone()); id } + pub fn change_server_bool( &mut self, tensor: &TensorDescription, device: &B::Device, server_device: &mut Self, ) -> Arc { - let tensor = self.handles.get_bool_tensor::(tensor); + let tensor = self.handles.get_bool_tensor::(tensor); let tensor = B::bool_to_device(tensor, device); let id = server_device.create_empty_handle(); server_device .handles - .register_bool_tensor(&id, tensor.clone()); + .register_bool_tensor::(&id, tensor.clone()); id } diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 11441a40a8..9b15b36d07 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -13,7 +13,7 @@ pub struct Context<'a, B: FusionBackend> { /// The tensor mapping where local tensor id points to the updated tensor description. pub tensors: &'a HashMap, /// Handle container to retrieve tensors based on their description. - pub handles: &'a mut HandleContainer, + pub handles: &'a mut HandleContainer, /// Float scalars found in the graph in the order they appeared. pub scalar_floats: &'a Vec, /// Int scalars found in the graph in the order they appeared. @@ -44,7 +44,7 @@ trait RelativeOpsScalar { impl OperationConverter { pub(crate) fn context<'a, B: FusionBackend>( &'a self, - handles: &'a mut HandleContainer, + handles: &'a mut HandleContainer, ) -> Context<'a, B> { Context { handles, diff --git a/crates/burn-fusion/src/stream/execution/base.rs b/crates/burn-fusion/src/stream/execution/base.rs index 7b2dce0f50..91f6b322cc 100644 --- a/crates/burn-fusion/src/stream/execution/base.rs +++ b/crates/burn-fusion/src/stream/execution/base.rs @@ -18,7 +18,7 @@ pub(crate) enum ExecutionMode { /// General trait to abstract how a single operation is executed. pub trait Operation: Send + Sync { /// Execute the operation. 
- fn execute(self: Box, handles: &mut HandleContainer); + fn execute(self: Box, handles: &mut HandleContainer); } impl OperationQueue { @@ -26,7 +26,7 @@ impl OperationQueue { pub(crate) fn execute( &mut self, id: ExecutionPlanId, - handles: &mut HandleContainer, + handles: &mut HandleContainer, store: &mut ExecutionPlanStore, ) { match &mut store.get_mut_unchecked(id).strategy { @@ -39,7 +39,7 @@ impl OperationQueue { fn execute_optimization( &mut self, - handles: &mut HandleContainer, + handles: &mut HandleContainer, optimization: &mut B::Optimization, ) { let num_drained = optimization.len(); @@ -51,7 +51,7 @@ impl OperationQueue { self.operations.drain(0..num_drained); } - fn execute_operations(&mut self, handles: &mut HandleContainer) { + fn execute_operations(&mut self, handles: &mut HandleContainer) { let num_drained = self.operations.len(); for operation in self.operations.drain(0..num_drained) { @@ -61,7 +61,7 @@ impl OperationQueue { self.drain_stream(num_drained, handles); } - fn drain_stream(&mut self, num_drained: usize, handles: &mut HandleContainer) { + fn drain_stream(&mut self, num_drained: usize, handles: &mut HandleContainer) { self.global[0..num_drained] .iter() .flat_map(|desc| desc.nodes()) diff --git a/crates/burn-fusion/src/stream/multi.rs b/crates/burn-fusion/src/stream/multi.rs index 035b6ed280..c12442775f 100644 --- a/crates/burn-fusion/src/stream/multi.rs +++ b/crates/burn-fusion/src/stream/multi.rs @@ -30,7 +30,7 @@ impl MultiStream { streams: Vec, desc: OperationDescription, operation: Box>, - handles: &mut HandleContainer, + handles: &mut HandleContainer, ) { let id = self.maybe_drain(streams, handles); @@ -65,7 +65,7 @@ impl MultiStream { } /// Drain the streams. - pub fn drain(&mut self, handles: &mut HandleContainer, id: StreamId) { + pub fn drain(&mut self, handles: &mut HandleContainer, id: StreamId) { if let Some(mut stream) = self.streams.remove(&id) { stream.processor.process( Segment::new(&mut stream.queue, handles), @@ -80,7 +80,7 @@ impl MultiStream { fn maybe_drain( &mut self, streams: Vec, - handles: &mut HandleContainer, + handles: &mut HandleContainer, ) -> StreamId { let streams = Self::remove_duplicate(streams); let current = StreamId::current(); @@ -113,7 +113,7 @@ impl MultiStream { output } - fn free_orphans(&self, handles: &mut HandleContainer) { + fn free_orphans(&self, handles: &mut HandleContainer) { let nodes = self .streams .values() @@ -134,7 +134,7 @@ struct Stream { #[derive(new)] struct Segment<'a, B: FusionBackend> { queue: &'a mut OperationQueue, - handles: &'a mut HandleContainer, + handles: &'a mut HandleContainer, } impl<'i, B: FusionBackend> StreamSegment for Segment<'i, B> { diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 89cf6b8fc2..b49467d126 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -1,5 +1,4 @@ use crate::{ - backend::Backend, repr::{ backend::ReprBackend, tensor::{TensorDescription, TensorId, TensorStatus}, @@ -11,36 +10,33 @@ use std::{collections::HashMap, sync::Arc}; /// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources /// are used optimally. #[derive(Default)] -pub struct HandleContainer { - handles: HashMap>, +pub struct HandleContainer { + handles: HashMap>, counter: u64, /// Handle candidates to be freed. pub handles_orphan: Vec, - /// The device on which all tensors are held. 
- pub device: B::Device, } /// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state -pub enum Handle { +pub enum Handle { /// No [tensor handle](ReprBackend::Handle) has been created yet NotInit, /// A [tensor handle](ReprBackend::Handle) has been created - Existing(B::Handle), + Existing(H), } -impl HandleContainer { +impl HandleContainer { /// Create a new HandleContainer - pub fn new(device_handle: B::Device) -> Self { + pub fn new() -> Self { Self { handles: HashMap::new(), handles_orphan: Vec::new(), counter: 0, - device: device_handle.clone(), } } /// Register a handle for the given [tensor id](TensorId). - pub fn register_handle(&mut self, id: TensorId, handle: B::Handle) { + pub fn register_handle(&mut self, id: TensorId, handle: H) { self.handles.insert(id, Handle::Existing(handle)); } @@ -51,7 +47,7 @@ impl HandleContainer { /// /// Make sure the status corresponds to the operation you want to execute the handle on, /// otherwise you might remove a tensor handle that will be required in the future. - pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> B::Handle { + pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H { let (id, handle) = self .handles .remove_entry(id) @@ -72,10 +68,13 @@ impl HandleContainer { /// Get the [float tensor](Backend::FloatTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). - pub fn get_float_tensor( + pub fn get_float_tensor( &mut self, tensor: &TensorDescription, - ) -> B::FloatTensorPrimitive { + ) -> B::FloatTensorPrimitive + where + B: ReprBackend, + { B::float_tensor::( self.get_handle(&tensor.id, &tensor.status), Shape::from(&tensor.shape), @@ -84,10 +83,13 @@ impl HandleContainer { /// Get the [int tensor](Backend::IntTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). - pub fn get_int_tensor( + pub fn get_int_tensor( &mut self, tensor: &TensorDescription, - ) -> B::IntTensorPrimitive { + ) -> B::IntTensorPrimitive + where + B: ReprBackend, + { B::int_tensor::( self.get_handle(&tensor.id, &tensor.status), Shape::from(&tensor.shape), @@ -96,10 +98,13 @@ impl HandleContainer { /// Get the [bool tensor](Backend::BoolTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). - pub fn get_bool_tensor( + pub fn get_bool_tensor( &mut self, tensor: &TensorDescription, - ) -> B::BoolTensorPrimitive { + ) -> B::BoolTensorPrimitive + where + B: ReprBackend, + { B::bool_tensor::( self.get_handle(&tensor.id, &tensor.status), Shape::from(&tensor.shape), @@ -107,31 +112,37 @@ impl HandleContainer { } /// Register a new [float tensor](Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). - pub fn register_float_tensor( + pub fn register_float_tensor( &mut self, id: &TensorId, tensor: B::FloatTensorPrimitive, - ) { + ) where + B: ReprBackend, + { let handle = B::float_tensor_handle::(tensor); self.handles.insert(*id, Handle::Existing(handle)); } /// Register a new [int tensor](Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). - pub fn register_int_tensor( + pub fn register_int_tensor( &mut self, id: &TensorId, tensor: B::IntTensorPrimitive, - ) { + ) where + B: ReprBackend, + { let handle = B::int_tensor_handle::(tensor); self.handles.insert(*id, Handle::Existing(handle)); } /// Register a new [bool tensor](Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). 
- pub fn register_bool_tensor( + pub fn register_bool_tensor( &mut self, id: &TensorId, tensor: B::BoolTensorPrimitive, - ) { + ) where + B: ReprBackend, + { let handle = B::bool_tensor_handle::(tensor); self.handles.insert(*id, Handle::Existing(handle)); } From 9ae3cc8decdcbeec288ea7ed31679de8bbf0956b Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 30 Apr 2024 10:26:09 -0400 Subject: [PATCH 02/25] Fix docs --- crates/burn-tensor/src/repr/handle.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index b49467d126..2a745ec6b2 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -66,7 +66,7 @@ impl HandleContainer { } } - /// Get the [float tensor](Backend::FloatTensorPrimitive) corresponding to the + /// Get the [float tensor](ReprBackend::FloatTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_float_tensor( &mut self, @@ -81,7 +81,7 @@ impl HandleContainer { ) } - /// Get the [int tensor](Backend::IntTensorPrimitive) corresponding to the + /// Get the [int tensor](ReprBackend::IntTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_int_tensor( &mut self, @@ -96,7 +96,7 @@ impl HandleContainer { ) } - /// Get the [bool tensor](Backend::BoolTensorPrimitive) corresponding to the + /// Get the [bool tensor](ReprBackend::BoolTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_bool_tensor( &mut self, @@ -111,7 +111,7 @@ impl HandleContainer { ) } - /// Register a new [float tensor](Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). + /// Register a new [float tensor](ReprBackend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_float_tensor( &mut self, id: &TensorId, @@ -123,7 +123,7 @@ impl HandleContainer { self.handles.insert(*id, Handle::Existing(handle)); } - /// Register a new [int tensor](Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). + /// Register a new [int tensor](ReprBackend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_int_tensor( &mut self, id: &TensorId, @@ -135,7 +135,7 @@ impl HandleContainer { self.handles.insert(*id, Handle::Existing(handle)); } - /// Register a new [bool tensor](Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). + /// Register a new [bool tensor](ReprBackend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_bool_tensor( &mut self, id: &TensorId, From 53ce575a2c7749c9ce869e571b080fb0650dd01c Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 30 Apr 2024 10:33:04 -0400 Subject: [PATCH 03/25] Remove dependency for the context --- crates/burn-fusion/src/backend.rs | 2 +- crates/burn-fusion/src/stream/context.rs | 10 +++------ crates/burn-jit/src/fusion/base.rs | 2 +- .../src/fusion/elemwise/optimization.rs | 21 ++++++++----------- crates/burn-jit/src/fusion/kernel.rs | 18 +++++----------- 5 files changed, 19 insertions(+), 34 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index 3d4167e5cc..4ae96cf273 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -103,7 +103,7 @@ pub trait OptimizationBuilder: Send { /// The operation created from the [builder](OptimizationBuilder). 
pub trait Optimization: Send { /// Execute the operation. - fn execute(&mut self, context: &mut Context<'_, B>); + fn execute(&mut self, context: &mut Context<'_, B::Handle>); /// The number of registered operations in this optimization. fn len(&self) -> usize; /// If the current optimization is empty. diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 9b15b36d07..b66c89a4b7 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -1,4 +1,3 @@ -use crate::FusionBackend; use burn_tensor::{repr::*, Element, ElementConversion}; use hashbrown::HashMap; @@ -9,11 +8,11 @@ use hashbrown::HashMap; /// It also contains all scalar values, which can change even for the same graph. They are sorted /// in the order in which they appear in the graph. #[derive(new)] -pub struct Context<'a, B: FusionBackend> { +pub struct Context<'a, H> { /// The tensor mapping where local tensor id points to the updated tensor description. pub tensors: &'a HashMap, /// Handle container to retrieve tensors based on their description. - pub handles: &'a mut HandleContainer, + pub handles: &'a mut HandleContainer, /// Float scalars found in the graph in the order they appeared. pub scalar_floats: &'a Vec, /// Int scalars found in the graph in the order they appeared. @@ -42,10 +41,7 @@ trait RelativeOpsScalar { } impl OperationConverter { - pub(crate) fn context<'a, B: FusionBackend>( - &'a self, - handles: &'a mut HandleContainer, - ) -> Context<'a, B> { + pub(crate) fn context<'a, H>(&'a self, handles: &'a mut HandleContainer) -> Context<'a, H> { Context { handles, tensors: &self.tensors_relative2global, diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index f1a49290e6..f97f92a532 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -32,7 +32,7 @@ where F: FloatElement, I: IntElement, { - fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitBackend>) { + fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle>) { match self { Self::ElementWise(op) => op.execute(context), } diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index ac6010bbdb..afa5198230 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -7,8 +7,8 @@ use super::{ use crate::{ codegen::dialect::gpu::WorkgroupSize, compute::JitAutotuneKey, - fusion::{kernel::FusionKernel, tracing::Trace}, - FloatElement, IntElement, JitBackend, Runtime, + fusion::{kernel::FusionKernel, tracing::Trace, JitFusionHandle}, + Runtime, }; use burn_common::id::IdGenerator; use burn_compute::client::ComputeClient; @@ -66,10 +66,7 @@ impl ElementWise { } impl ElementWise> { - pub(crate) fn execute( - &mut self, - context: &mut Context<'_, JitBackend>, - ) { + pub(crate) fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { let client = R::client(&self.device); let key = JitAutotuneKey::FusionElemWise(FusionElemWiseAutotuneKey::new( @@ -84,9 +81,9 @@ impl ElementWise> { } } - fn run_kernel( + fn run_kernel( &mut self, - context: &mut Context<'_, JitBackend>, + context: &mut Context<'_, JitFusionHandle>, client: ComputeClient, fastest_set_index: usize, ) { @@ -109,9 +106,9 @@ impl ElementWise> { kernel.execute(); } - fn run_autotune( + fn run_autotune( &mut self, - context: &mut Context<'_, JitBackend>, + context: 
&mut Context<'_, JitFusionHandle>, client: ComputeClient, key: JitAutotuneKey, ) { @@ -155,9 +152,9 @@ impl ElementWise> { } /// The first output is chosen when possible, otherwise the first input is chosen. - pub(crate) fn autotune_shape<'a, F: FloatElement, I: IntElement>( + pub(crate) fn autotune_shape<'a>( &self, - context: &mut Context<'a, JitBackend>, + context: &mut Context<'a, JitFusionHandle>, ) -> &'a [usize] { let info = self.trace.running(); diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs index 0304bfeb49..5ada67bdd3 100644 --- a/crates/burn-jit/src/fusion/kernel.rs +++ b/crates/burn-jit/src/fusion/kernel.rs @@ -9,9 +9,6 @@ use crate::fusion::strides_dyn_rank; use crate::fusion::JitFusionHandle; use crate::gpu::ComputeShader; use crate::kernel::GpuComputeShaderPhase; -use crate::FloatElement; -use crate::IntElement; -use crate::JitBackend; use crate::Runtime; use burn_compute::client::ComputeClient; use burn_compute::server::Binding; @@ -19,7 +16,6 @@ use burn_compute::tune::AutotuneOperation; use burn_fusion::stream::Context; use burn_tensor::repr::TensorDescription; use burn_tensor::repr::TensorStatus; -use burn_tensor::Device; use std::marker::PhantomData; use std::sync::Arc; @@ -108,18 +104,16 @@ impl From> for AutotunableKernel { } impl FusionKernel { - pub fn create( + pub fn create( factory: &K, running_info: &ExecutionInfo<'_>, - context: &mut Context<'_, JitBackend>, - device: Device>, + context: &mut Context<'_, JitFusionHandle>, + device: R::Device, client: ComputeClient, stateful: bool, ) -> ExecutableKernel where K: FusionKernelFactory, - F: FloatElement, - I: IntElement, { let (handles_input, inputs_description_updated, outputs_description_updated) = process_inputs_outputs( @@ -273,10 +267,10 @@ fn register_info_tensor( } } -fn process_inputs_outputs<'a, R, F, I>( +fn process_inputs_outputs<'a, R>( inputs: &[&TensorDescription], outputs: &[&TensorDescription], - context: &'a mut Context<'_, JitBackend>, + context: &'a mut Context<'_, JitFusionHandle>, stateful: bool, ) -> ( Vec>, @@ -285,8 +279,6 @@ fn process_inputs_outputs<'a, R, F, I>( ) where R: Runtime, - F: FloatElement, - I: IntElement, { let mut inputs_description_updated = Vec::with_capacity(inputs.len()); let mut outputs_description_updated = Vec::with_capacity(outputs.len()); From 2e31a3404ce35e59e560fa9544900c9868063a25 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 30 Apr 2024 12:51:09 -0400 Subject: [PATCH 04/25] Fusion runtime --- crates/burn-fusion/src/backend.rs | 35 ++- crates/burn-fusion/src/client/base.rs | 9 +- crates/burn-fusion/src/client/mutex.rs | 26 +- crates/burn-fusion/src/ops/binary.rs | 20 +- crates/burn-fusion/src/ops/boolean.rs | 101 +++++--- crates/burn-fusion/src/ops/float.rs | 235 ++++++++++-------- crates/burn-fusion/src/ops/int.rs | 212 +++++++++------- crates/burn-fusion/src/ops/module.rs | 46 ++-- crates/burn-fusion/src/ops/unary.rs | 55 ++-- crates/burn-fusion/src/server.rs | 69 +++-- crates/burn-fusion/src/stream/base.rs | 16 +- .../burn-fusion/src/stream/execution/base.rs | 20 +- crates/burn-fusion/src/stream/multi.rs | 46 ++-- crates/burn-jit/src/fusion/base.rs | 22 +- 14 files changed, 520 insertions(+), 392 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index 4ae96cf273..c17ce0e351 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -11,7 +11,7 @@ use std::marker::PhantomData; pub(crate) static CLIENTS: FusionClientLocator = 
FusionClientLocator::new(); -pub(crate) fn get_client(device: &B::Device) -> B::FusionClient { +pub(crate) fn get_client(device: &Device) -> B::FusionClient { CLIENTS.client(device) } @@ -101,9 +101,9 @@ pub trait OptimizationBuilder: Send { } /// The operation created from the [builder](OptimizationBuilder). -pub trait Optimization: Send { +pub trait Optimization: Send { /// Execute the operation. - fn execute(&mut self, context: &mut Context<'_, B::Handle>); + fn execute(&mut self, context: &mut Context<'_, R::FusionHandle>); /// The number of registered operations in this optimization. fn len(&self) -> usize; /// If the current optimization is empty. @@ -111,22 +111,37 @@ pub trait Optimization: Send { self.len() == 0 } /// Returns the state that can be serialized. - fn to_state(&self) -> B::OptimizationState; + fn to_state(&self) -> R::OptimizationState; /// Create the optimization from the state. - fn from_state(device: &B::Device, state: B::OptimizationState) -> Self; + fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self; } /// Trait that allows an existing [backend](Backend) to specify graph optimizations using /// [operation builder](crate::OptimizationBuilder). -pub trait FusionBackend: Backend + ReprBackend { +pub trait FusionRuntime: Send + Sync +where + Self: Sized, +{ /// The state that can be serialized for an optimization. type OptimizationState: Serialize + DeserializeOwned; /// Optimization type for the backend. type Optimization: Optimization; - /// What kind of client should be used. - type FusionClient: FusionClient; + /// Handle + type FusionHandle: Clone; + /// Device + type FusionDevice: Clone; /// The list of optimizations that will be used to optimize the computational graph. - fn optimizations(device: Device) - -> Vec>>; + fn optimizations( + device: Self::FusionDevice, + ) -> Vec>>; +} + +/// Trait that allows an existing [backend](Backend) to specify graph optimizations using +/// [operation builder](crate::OptimizationBuilder). +pub trait FusionBackend: ReprBackend { + /// What kind of client should be used. + type FusionClient: FusionClient; + /// The runtime. + type FusionRuntime: FusionRuntime; } diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 377f98b8f0..eca898b5c9 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -17,12 +17,9 @@ pub trait FusionClient: Send + Sync + Clone { /// Create a new client for the given [device](Backend::Device). fn new(device: Device) -> Self; /// Register a new [tensor operation description](OperationDescription). - fn register + 'static>( - &self, - streams: Vec, - description: OperationDescription, - operation: O, - ); + fn register(&self, streams: Vec, description: OperationDescription, operation: O) + where + O: Operation<::FusionRuntime> + 'static; /// Register all lazy computation. fn drain(&self); /// Get the current device used by all operations handled by this client. 
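Before the mutex client diff below, it is worth pausing on what patch 04 has set up so far: FusionRuntime now carries the types the stream executor actually touches (handle, device, optimization state), while FusionBackend merely links a concrete backend to such a runtime. A toy, compilable sketch of that split — CpuRuntime, CpuBackend and the Vec<f32> handle are assumed placeholder names, not burn types, and the Vec of handles stands in for the real HandleContainer:

// Runtime side: everything the fusion stream needs to run operations.
trait FusionRuntime: Send + Sync + Sized {
    type FusionHandle: Clone;
    type FusionDevice: Clone;
}

// Backend side: a backend only declares which runtime drives its fusion.
trait FusionBackend {
    type FusionRuntime: FusionRuntime;
}

// Operations are generic over the runtime rather than the backend, so the
// queue/stream machinery is written once per runtime.
trait Operation<R: FusionRuntime>: Send + Sync {
    fn execute(self: Box<Self>, handles: &mut Vec<R::FusionHandle>);
}

struct CpuRuntime;
impl FusionRuntime for CpuRuntime {
    type FusionHandle = Vec<f32>;
    type FusionDevice = ();
}

struct CpuBackend;
impl FusionBackend for CpuBackend {
    type FusionRuntime = CpuRuntime;
}

// A trivial operation: fills a fresh handle with a constant.
struct Fill(f32);
impl Operation<CpuRuntime> for Fill {
    fn execute(self: Box<Self>, handles: &mut Vec<Vec<f32>>) {
        handles.push(vec![self.0; 4]);
    }
}

fn main() {
    let mut handles = Vec::new();
    let op: Box<dyn Operation<CpuRuntime>> = Box::new(Fill(1.0));
    op.execute(&mut handles);
    assert_eq!(handles[0], vec![1.0; 4]);
}

This is presumably also why the FloatElement/IntElement parameters disappear from the JIT fusion kernels in patch 03 above: backends that differ only in element types can share one runtime, one handle type, and therefore one set of fused kernels.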
diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 940d331f0f..6e06963bd5 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -16,7 +16,7 @@ pub struct MutexFusionClient where B: FusionBackend, { - server: Arc>>, + server: Arc>>, device: B::Device, } @@ -45,12 +45,10 @@ where } } - fn register + 'static>( - &self, - streams: Vec, - description: OperationDescription, - operation: O, - ) { + fn register(&self, streams: Vec, description: OperationDescription, operation: O) + where + O: Operation<::FusionRuntime> + 'static, + { self.server .lock() .register(streams, description, Box::new(operation)) @@ -89,7 +87,7 @@ where tensor: TensorDescription, stream: StreamId, ) -> burn_tensor::Reader, D>> { - self.server.lock().read_float(tensor, stream) + self.server.lock().read_float::(tensor, stream) } fn read_tensor_int( @@ -98,7 +96,7 @@ where id: StreamId, ) -> burn_tensor::Reader, D>> { - self.server.lock().read_int(tensor, id) + self.server.lock().read_int::(tensor, id) } fn read_tensor_bool( @@ -106,7 +104,7 @@ where tensor: TensorDescription, stream: StreamId, ) -> burn_tensor::Reader> { - self.server.lock().read_bool(tensor, stream) + self.server.lock().read_bool::(tensor, stream) } fn change_client_float( @@ -120,7 +118,7 @@ where server_current.drain_stream(stream); let id = - server_current.change_server_float::(&tensor, &client.device, &mut server_other); + server_current.change_server_float::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); @@ -138,7 +136,8 @@ where let mut server_current = self.server.lock(); server_current.drain_stream(stream); - let id = server_current.change_server_int::(&tensor, &client.device, &mut server_other); + let id = + server_current.change_server_int::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); @@ -156,7 +155,8 @@ where let mut server_current = self.server.lock(); server_current.drain_stream(stream); - let id = server_current.change_server_bool::(&tensor, &client.device, &mut server_other); + let id = + server_current.change_server_bool::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); diff --git a/crates/burn-fusion/src/ops/binary.rs b/crates/burn-fusion/src/ops/binary.rs index 388c44cd4a..caa6093b61 100644 --- a/crates/burn-fusion/src/ops/binary.rs +++ b/crates/burn-fusion/src/ops/binary.rs @@ -6,11 +6,12 @@ macro_rules! binary_float_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer<::Handle>) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let rhs = handles.get_float_tensor::(&self.desc.rhs); @@ -30,11 +31,12 @@ macro_rules! binary_float_cmp_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let rhs = handles.get_float_tensor::(&self.desc.rhs); @@ -54,11 +56,12 @@ macro_rules! 
binary_int_cmp_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let rhs = handles.get_int_tensor::(&self.desc.rhs); @@ -88,11 +91,12 @@ macro_rules! binary_int_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let rhs = handles.get_int_tensor::(&self.desc.rhs); diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index f1ea9bd49f..c156f2d3cb 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -1,3 +1,5 @@ +use std::marker::PhantomData; + use crate::{ client::FusionClient, get_client, @@ -58,11 +60,12 @@ impl BoolTensorOps for Fusion { tensor: BoolTensor, ) -> burn_tensor::ops::IntTensor { #[derive(new)] - struct IntoIntOps { + struct IntoIntOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for IntoIntOps { + impl Operation for IntoIntOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_int(input); @@ -81,7 +84,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Bool(BoolOperationDescription::IntoInt(desc.clone())), - IntoIntOps::::new(desc), + IntoIntOps::::new(desc), ); out @@ -91,11 +94,12 @@ impl BoolTensorOps for Fusion { tensor: BoolTensor, ) -> burn_tensor::ops::FloatTensor { #[derive(new)] - struct IntoFloatOps { + struct IntoFloatOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for IntoFloatOps { + impl Operation for IntoFloatOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_float(input); @@ -113,7 +117,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Bool(BoolOperationDescription::IntoFloat(desc.clone())), - IntoFloatOps::::new(desc), + IntoFloatOps::::new(desc), ); out @@ -150,11 +154,14 @@ impl BoolTensorOps for Fusion { shape: Shape, ) -> BoolTensor { #[derive(new)] - struct ReshapeDimsOps { + struct ReshapeDimsOps { desc: ReshapeDescription, + _b: PhantomData, } - impl Operation for ReshapeDimsOps { + impl Operation + for ReshapeDimsOps + { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_reshape::(input, Shape::from(&self.desc.out.shape)); @@ -173,7 +180,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::Reshape(desc.clone())), - ReshapeDimsOps::::new(desc), + ReshapeDimsOps::::new(desc), ); out @@ -184,11 +191,14 @@ impl BoolTensorOps for Fusion { ranges: [std::ops::Range; D2], ) -> BoolTensor { #[derive(new)] - struct SliceOps { + struct SliceOps { desc: SliceOperationDescription, + _b: PhantomData, } - impl Operation for SliceOps { + impl Operation + for SliceOps + { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); @@ -216,7 +226,7 @@ impl BoolTensorOps for Fusion { 
out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::Slice(desc.clone())), - SliceOps::::new(desc), + SliceOps::::new(desc), ); out @@ -228,11 +238,14 @@ impl BoolTensorOps for Fusion { value: BoolTensor, ) -> BoolTensor { #[derive(new)] - struct SliceAssignOps { + struct SliceAssignOps { desc: SliceAssignOperationDescription, + _b: PhantomData, } - impl Operation for SliceAssignOps { + impl Operation + for SliceAssignOps + { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let value = handles.get_bool_tensor::(&self.desc.value); @@ -262,7 +275,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseBool(BaseOperationDescription::SliceAssign(desc.clone())), - SliceAssignOps::::new(desc), + SliceAssignOps::::new(desc), ); out @@ -273,11 +286,12 @@ impl BoolTensorOps for Fusion { dim: usize, ) -> BoolTensor { #[derive(new)] - struct CatOps { + struct CatOps { desc: CatOperationDescription, + _b: PhantomData, } - impl Operation for CatOps { + impl Operation for CatOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc @@ -314,7 +328,7 @@ impl BoolTensorOps for Fusion { client.register( streams, OperationDescription::BaseBool(BaseOperationDescription::Cat(desc.clone())), - CatOps::::new(desc), + CatOps::::new(desc), ); out @@ -325,11 +339,12 @@ impl BoolTensorOps for Fusion { rhs: BoolTensor, ) -> BoolTensor { #[derive(new)] - struct EqualOps { + struct EqualOps { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for EqualOps { + impl Operation for EqualOps { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_bool_tensor::(&self.desc.lhs); let rhs = handles.get_bool_tensor::(&self.desc.rhs); @@ -352,7 +367,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseBool(BaseOperationDescription::Equal(desc.clone())), - EqualOps::::new(desc), + EqualOps::::new(desc), ); out @@ -360,11 +375,12 @@ impl BoolTensorOps for Fusion { fn bool_not(tensor: BoolTensor) -> BoolTensor { #[derive(new)] - struct NotOps { + struct NotOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for NotOps { + impl Operation for NotOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_not(input); @@ -383,7 +399,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Bool(BoolOperationDescription::Not(desc.clone())), - NotOps::::new(desc), + NotOps::::new(desc), ); out @@ -395,11 +411,12 @@ impl BoolTensorOps for Fusion { dim2: usize, ) -> BoolTensor { #[derive(new)] - struct SwapDimsOps { + struct SwapDimsOps { desc: SwapDimsDescription, + _b: PhantomData, } - impl Operation for SwapDimsOps { + impl Operation for SwapDimsOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2); @@ -423,7 +440,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::SwapDims(desc.clone())), - SwapDimsOps::::new(desc), + SwapDimsOps::::new(desc), ); out @@ -434,11 +451,12 @@ impl BoolTensorOps for Fusion { axes: [usize; D], ) -> BoolTensor { #[derive(new)] - struct PermuteDimsOps { + struct PermuteDimsOps { desc: 
PermuteOperationDescription, + _b: PhantomData, } - impl Operation for PermuteDimsOps { + impl Operation for PermuteDimsOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); @@ -463,7 +481,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Permute(desc.clone())), - PermuteDimsOps::::new(desc), + PermuteDimsOps::::new(desc), ); out @@ -474,11 +492,14 @@ impl BoolTensorOps for Fusion { shape: Shape, ) -> BoolTensor { #[derive(new)] - struct ExpandOps { + struct ExpandOps { desc: ExpandOperationDescription, + _b: PhantomData, } - impl Operation for ExpandOps { + impl Operation + for ExpandOps + { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); @@ -501,7 +522,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::Expand(desc.clone())), - ExpandOps::::new(desc), + ExpandOps::::new(desc), ); out @@ -512,11 +533,12 @@ impl BoolTensorOps for Fusion { axes: &[usize], ) -> BoolTensor { #[derive(new)] - struct FlipOps { + struct FlipOps { desc: FlipOperationDescription, + _b: PhantomData, } - impl Operation for FlipOps { + impl Operation for FlipOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_flip(input, self.desc.axes.as_slice()); @@ -536,7 +558,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::Flip(desc.clone())), - FlipOps::::new(desc), + FlipOps::::new(desc), ); out @@ -548,11 +570,12 @@ impl BoolTensorOps for Fusion { times: usize, ) -> BoolTensor { #[derive(new)] - struct RepeatOps { + struct RepeatOps { desc: RepeatOperationDescription, + _b: PhantomData, } - impl Operation for RepeatOps { + impl Operation for RepeatOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); @@ -576,7 +599,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())), - RepeatOps::::new(desc), + RepeatOps::::new(desc), ); out diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index e220cf79e3..eb478af67b 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -12,7 +12,7 @@ use burn_tensor::{ repr::*, Data, Device, Distribution, ElementConversion, Reader, Shape, }; -use std::ops::Range; +use std::{marker::PhantomData, ops::Range}; impl FloatTensorOps for Fusion { fn float_from_data( @@ -41,7 +41,7 @@ impl FloatTensorOps for Fusion { device: Device, } - impl Operation for RandomOps { + impl Operation for RandomOps { fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.out.shape.clone()); let output: B::FloatTensorPrimitive = @@ -75,7 +75,7 @@ impl FloatTensorOps for Fusion { device: Device, } - impl Operation for ZerosOps { + impl Operation for ZerosOps { fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); let output = B::float_zeros::(shape, &self.device); @@ -105,7 +105,7 @@ impl FloatTensorOps for Fusion { device: Device, } - impl Operation for 
OnesOps { + impl Operation for OnesOps { fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); let output = B::float_ones::(shape, &self.device); @@ -140,7 +140,7 @@ impl FloatTensorOps for Fusion { device: Device, } - impl Operation for FullOps { + impl Operation for FullOps { fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); let output: B::FloatTensorPrimitive = @@ -202,11 +202,12 @@ impl FloatTensorOps for Fusion { fn float_into_int(tensor: FloatTensor) -> IntTensor { #[derive(new)] - struct IntoIntOps { + struct IntoIntOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for IntoIntOps { + impl Operation for IntoIntOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_into_int(input); @@ -225,7 +226,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::IntoInt(desc.clone())), - IntoIntOps::::new(desc), + IntoIntOps::::new(desc), ); out @@ -260,7 +261,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Add(desc.clone())), - AddOps::::new(desc), + AddOps::::new(desc), ); out @@ -285,7 +286,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::AddScalar( desc.clone(), )), - AddOps::::new(desc), + AddOps::::new(desc), ); out @@ -297,11 +298,12 @@ impl FloatTensorOps for Fusion { max: FloatElem, ) -> FloatTensor { #[derive(new)] - struct ClampOps { + struct ClampOps { desc: ClampOperationDescription, + _b: PhantomData, } - impl Operation for ClampOps { + impl Operation for ClampOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_clamp(input, self.desc.min.elem(), self.desc.max.elem()); @@ -322,7 +324,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Clamp(desc.clone())), - ClampOps::::new(desc), + ClampOps::::new(desc), ); out @@ -348,7 +350,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Sub(desc.clone())), - SubOps::::new(desc), + SubOps::::new(desc), ); out @@ -373,7 +375,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::SubScalar( desc.clone(), )), - SubOps::::new(desc), + SubOps::::new(desc), ); out @@ -399,7 +401,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Mul(desc.clone())), - MulOps::::new(desc), + MulOps::::new(desc), ); out @@ -424,7 +426,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::MulScalar( desc.clone(), )), - MulOps::::new(desc), + MulOps::::new(desc), ); out @@ -450,7 +452,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Div(desc.clone())), - DivOps::::new(desc), + DivOps::::new(desc), ); out @@ -475,7 +477,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::DivScalar( desc.clone(), )), - DivOps::::new(desc), + 
DivOps::::new(desc), ); out @@ -500,7 +502,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::RemScalar( desc.clone(), )), - ModOps::::new(desc), + ModOps::::new(desc), ); out @@ -529,7 +531,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::Float(FloatOperationDescription::Matmul(desc.clone())), - MatmulOps::::new(desc), + MatmulOps::::new(desc), ); out @@ -541,11 +543,12 @@ impl FloatTensorOps for Fusion { dim2: usize, ) -> FloatTensor { #[derive(new)] - struct SwapDimsOps { + struct SwapDimsOps { desc: SwapDimsDescription, + _b: PhantomData, } - impl Operation for SwapDimsOps { + impl Operation for SwapDimsOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2); @@ -569,7 +572,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseFloat(BaseOperationDescription::SwapDims(desc.clone())), - SwapDimsOps::::new(desc), + SwapDimsOps::::new(desc), ); out.stream = stream; @@ -581,11 +584,14 @@ impl FloatTensorOps for Fusion { shape: Shape, ) -> FloatTensor { #[derive(new)] - struct ReshapeDimsOps { + struct ReshapeDimsOps { desc: ReshapeDescription, + _b: PhantomData, } - impl Operation for ReshapeDimsOps { + impl Operation + for ReshapeDimsOps + { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_reshape::(input, Shape::from(&self.desc.out.shape)); @@ -604,7 +610,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseFloat(BaseOperationDescription::Reshape(desc.clone())), - ReshapeDimsOps::::new(desc), + ReshapeDimsOps::::new(desc), ); out @@ -616,11 +622,12 @@ impl FloatTensorOps for Fusion { indices: IntTensor, ) -> FloatTensor { #[derive(new)] - struct GatherOps { + struct GatherOps { desc: GatherOperationDescription, + _b: PhantomData, } - impl Operation for GatherOps { + impl Operation for GatherOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); @@ -644,7 +651,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Gather(desc.clone())), - GatherOps::::new(desc), + GatherOps::::new(desc), ); out @@ -657,11 +664,12 @@ impl FloatTensorOps for Fusion { value: FloatTensor, ) -> FloatTensor { #[derive(new)] - struct ScatterOps { + struct ScatterOps { desc: ScatterOperationDescription, + _b: PhantomData, } - impl Operation for ScatterOps { + impl Operation for ScatterOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); @@ -690,7 +698,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2, stream_3], OperationDescription::NumericFloat(NumericOperationDescription::Scatter(desc.clone())), - ScatterOps::::new(desc), + ScatterOps::::new(desc), ); out @@ -702,11 +710,12 @@ impl FloatTensorOps for Fusion { indices: IntTensor, ) -> FloatTensor { #[derive(new)] - struct SelectOps { + struct SelectOps { desc: SelectOperationDescription, + _b: PhantomData, } - impl Operation for SelectOps { + impl Operation for SelectOps { fn 
execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); @@ -731,7 +740,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Select(desc.clone())), - SelectOps::::new(desc), + SelectOps::::new(desc), ); out @@ -744,11 +753,12 @@ impl FloatTensorOps for Fusion { value: FloatTensor, ) -> FloatTensor { #[derive(new)] - struct SelectAssignOps { + struct SelectAssignOps { desc: SelectAssignOperationDescription, + _b: PhantomData, } - impl Operation for SelectAssignOps { + impl Operation for SelectAssignOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); @@ -778,7 +788,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::SelectAssign( desc.clone(), )), - SelectAssignOps::::new(desc), + SelectAssignOps::::new(desc), ); out @@ -789,11 +799,14 @@ impl FloatTensorOps for Fusion { ranges: [Range; D2], ) -> FloatTensor { #[derive(new)] - struct SliceOps { + struct SliceOps { desc: SliceOperationDescription, + _b: PhantomData, } - impl Operation for SliceOps { + impl Operation + for SliceOps + { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); @@ -820,7 +833,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseFloat(BaseOperationDescription::Slice(desc.clone())), - SliceOps::::new(desc), + SliceOps::::new(desc), ); out @@ -832,11 +845,14 @@ impl FloatTensorOps for Fusion { value: FloatTensor, ) -> FloatTensor { #[derive(new)] - struct SliceAssignOps { + struct SliceAssignOps { desc: SliceAssignOperationDescription, + _b: PhantomData, } - impl Operation for SliceAssignOps { + impl Operation + for SliceAssignOps + { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let value = handles.get_float_tensor::(&self.desc.value); @@ -865,7 +881,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseFloat(BaseOperationDescription::SliceAssign(desc.clone())), - SliceAssignOps::::new(desc), + SliceAssignOps::::new(desc), ); out @@ -877,11 +893,12 @@ impl FloatTensorOps for Fusion { value: FloatTensor, ) -> FloatTensor { #[derive(new)] - struct MaskWhereOps { + struct MaskWhereOps { desc: MaskWhereOperationDescription, + _b: PhantomData, } - impl Operation for MaskWhereOps { + impl Operation for MaskWhereOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let value = handles.get_float_tensor::(&self.desc.value); @@ -910,7 +927,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::MaskWhere( desc.clone(), )), - MaskWhereOps::::new(desc), + MaskWhereOps::::new(desc), ); out @@ -922,11 +939,12 @@ impl FloatTensorOps for Fusion { value: FloatElem, ) -> FloatTensor { #[derive(new)] - struct MaskFillOps { + struct MaskFillOps { desc: MaskFillOperationDescription, + _b: PhantomData, } - impl Operation for MaskFillOps { + impl Operation for MaskFillOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let mask = 
handles.get_bool_tensor::(&self.desc.mask); @@ -950,7 +968,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::MaskFill(desc.clone())), - MaskFillOps::::new(desc), + MaskFillOps::::new(desc), ); out @@ -976,7 +994,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseFloat(BaseOperationDescription::Equal(desc.clone())), - EqualOps::::new(desc), + EqualOps::::new(desc), ); out @@ -1001,7 +1019,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::EqualElem( desc.clone(), )), - EqualElemOps::::new(desc), + EqualElemOps::::new(desc), ); out @@ -1027,7 +1045,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Greater(desc.clone())), - GreaterOps::::new(desc), + GreaterOps::::new(desc), ); out @@ -1052,7 +1070,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::GreaterElem( desc.clone(), )), - GreaterElemOps::::new(desc), + GreaterElemOps::::new(desc), ); out @@ -1080,7 +1098,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqual( desc.clone(), )), - GreaterEqualOps::::new(desc), + GreaterEqualOps::::new(desc), ); out @@ -1105,7 +1123,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqualElem( desc.clone(), )), - GreaterEqualElemOps::::new(desc), + GreaterEqualElemOps::::new(desc), ); out @@ -1131,7 +1149,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Lower(desc.clone())), - LowerOps::::new(desc), + LowerOps::::new(desc), ); out @@ -1156,7 +1174,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::LowerElem( desc.clone(), )), - LowerElemOps::::new(desc), + LowerElemOps::::new(desc), ); out @@ -1184,7 +1202,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::LowerEqual( desc.clone(), )), - LowerEqualOps::::new(desc), + LowerEqualOps::::new(desc), ); out @@ -1209,7 +1227,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::LowerEqualElem( desc.clone(), )), - LowerEqualElemOps::::new(desc), + LowerEqualElemOps::::new(desc), ); out @@ -1228,7 +1246,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Sum(desc.clone())), - SumOps::::new(desc), + SumOps::::new(desc), ); out @@ -1253,7 +1271,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::SumDim(desc.clone())), - SumDimOps::::new(desc), + SumDimOps::::new(desc), ); out @@ -1272,7 +1290,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Mean(desc.clone())), - MeanOps::::new(desc), + MeanOps::::new(desc), ); out @@ -1297,7 +1315,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::MeanDim(desc.clone())), - MeanDimOps::::new(desc), + MeanDimOps::::new(desc), ); out @@ -1316,7 +1334,7 @@ impl FloatTensorOps for 
Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Exp(desc.clone())), - ExpOps::::new(desc), + ExpOps::::new(desc), ); out @@ -1335,7 +1353,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Log(desc.clone())), - LogOps::::new(desc), + LogOps::::new(desc), ); out @@ -1354,7 +1372,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Log1p(desc.clone())), - Log1pOps::::new(desc), + Log1pOps::::new(desc), ); out @@ -1377,7 +1395,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::PowfScalar(desc.clone())), - PowfOps::::new(desc), + PowfOps::::new(desc), ); out @@ -1396,7 +1414,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Sqrt(desc.clone())), - SqrtOps::::new(desc), + SqrtOps::::new(desc), ); out @@ -1415,7 +1433,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Abs(desc.clone())), - AbsOps::::new(desc), + AbsOps::::new(desc), ); out @@ -1434,7 +1452,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Cos(desc.clone())), - CosOps::::new(desc), + CosOps::::new(desc), ); out @@ -1453,7 +1471,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Sin(desc.clone())), - SinOps::::new(desc), + SinOps::::new(desc), ); out @@ -1472,7 +1490,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Tanh(desc.clone())), - TanhOps::::new(desc), + TanhOps::::new(desc), ); out @@ -1490,7 +1508,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Recip(desc.clone())), - Recip::::new(desc), + Recip::::new(desc), ); out @@ -1509,7 +1527,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Erf(desc.clone())), - TanhOps::::new(desc), + TanhOps::::new(desc), ); out @@ -1520,11 +1538,12 @@ impl FloatTensorOps for Fusion { dim: usize, ) -> FloatTensor { #[derive(new)] - struct CatOps { + struct CatOps { desc: CatOperationDescription, + _b: PhantomData, } - impl Operation for CatOps { + impl Operation for CatOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc @@ -1560,7 +1579,7 @@ impl FloatTensorOps for Fusion { client.register( streams, OperationDescription::BaseFloat(BaseOperationDescription::Cat(desc.clone())), - CatOps::::new(desc), + CatOps::::new(desc), ); out @@ -1585,7 +1604,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::ArgMax(desc.clone())), - ArgMaxOps::::new(desc), + ArgMaxOps::::new(desc), ); out @@ -1597,11 +1616,12 @@ impl FloatTensorOps for Fusion { times: usize, ) -> FloatTensor { #[derive(new)] - struct RepeatOps { + struct RepeatOps { desc: RepeatOperationDescription, + _b: PhantomData, } - impl Operation for RepeatOps { + impl Operation for RepeatOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); @@ 
-1625,7 +1645,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())), - RepeatOps::::new(desc), + RepeatOps::::new(desc), ); out @@ -1650,7 +1670,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::ArgMin(desc.clone())), - ArgMinOps::::new(desc), + ArgMinOps::::new(desc), ); out @@ -1669,7 +1689,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Max(desc.clone())), - MaxOps::::new(desc), + MaxOps::::new(desc), ); out @@ -1694,7 +1714,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::MaxDim(desc.clone())), - MaxDimOps::::new(desc), + MaxDimOps::::new(desc), ); out @@ -1705,11 +1725,12 @@ impl FloatTensorOps for Fusion { dim: usize, ) -> (FloatTensor, IntTensor) { #[derive(new)] - struct MaxDimWithIndicesOps { + struct MaxDimWithIndicesOps { desc: ReduceDimWithIndicesDescription, + _b: PhantomData, } - impl Operation for MaxDimWithIndicesOps { + impl Operation for MaxDimWithIndicesOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim); @@ -1737,7 +1758,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::MaxDimWithIndices( desc.clone(), )), - MaxDimWithIndicesOps::::new(desc), + MaxDimWithIndicesOps::::new(desc), ); (out, out_indices) @@ -1756,7 +1777,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Min(desc.clone())), - MinOps::::new(desc), + MinOps::::new(desc), ); out @@ -1781,7 +1802,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::MinDim(desc.clone())), - MinDimOps::::new(desc), + MinDimOps::::new(desc), ); out @@ -1792,11 +1813,12 @@ impl FloatTensorOps for Fusion { dim: usize, ) -> (FloatTensor, IntTensor) { #[derive(new)] - struct MinDimWithIndicesOps { + struct MinDimWithIndicesOps { desc: ReduceDimWithIndicesDescription, + _b: PhantomData, } - impl Operation for MinDimWithIndicesOps { + impl Operation for MinDimWithIndicesOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim); @@ -1824,7 +1846,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::MinDimWithIndices( desc.clone(), )), - MinDimWithIndicesOps::::new(desc), + MinDimWithIndicesOps::::new(desc), ); (out, out_indices) @@ -1850,7 +1872,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Powf(desc.clone())), - PowOps::::new(desc), + PowOps::::new(desc), ); out @@ -1861,11 +1883,12 @@ impl FloatTensorOps for Fusion { axes: [usize; D], ) -> FloatTensor { #[derive(new)] - struct PermuteDimsOps { + struct PermuteDimsOps { desc: PermuteOperationDescription, + _b: PhantomData, } - impl Operation for PermuteDimsOps { + impl Operation for PermuteDimsOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = 
handles.get_float_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); @@ -1890,7 +1913,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Permute(desc.clone())), - PermuteDimsOps::::new(desc), + PermuteDimsOps::::new(desc), ); out @@ -1901,11 +1924,14 @@ impl FloatTensorOps for Fusion { shape: Shape, ) -> FloatTensor { #[derive(new)] - struct ExpandOps { + struct ExpandOps { desc: ExpandOperationDescription, + _b: PhantomData, } - impl Operation for ExpandOps { + impl Operation + for ExpandOps + { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); @@ -1928,7 +1954,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseFloat(BaseOperationDescription::Expand(desc.clone())), - ExpandOps::::new(desc), + ExpandOps::::new(desc), ); out @@ -1939,11 +1965,12 @@ impl FloatTensorOps for Fusion { axes: &[usize], ) -> FloatTensor { #[derive(new)] - struct FlipOps { + struct FlipOps { desc: FlipOperationDescription, + _b: PhantomData, } - impl Operation for FlipOps { + impl Operation for FlipOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_flip(input, &self.desc.axes); @@ -1963,7 +1990,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Flip(desc.clone())), - FlipOps::::new(desc), + FlipOps::::new(desc), ); out diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 4ebd0b7fe0..efad010d6d 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -13,6 +13,7 @@ use burn_tensor::{ Data, Device, Distribution, ElementConversion, Reader, Shape, }; use core::ops::Range; +use std::marker::PhantomData; impl IntTensorOps for Fusion { fn int_empty(shape: Shape, device: &Device) -> IntTensor { @@ -72,11 +73,14 @@ impl IntTensorOps for Fusion { shape: Shape, ) -> IntTensor { #[derive(new)] - struct ReshapeDimsOps { + struct ReshapeDimsOps { desc: ReshapeDescription, + _b: PhantomData, } - impl Operation for ReshapeDimsOps { + impl Operation + for ReshapeDimsOps + { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_reshape::(input, Shape::from(&self.desc.out.shape)); @@ -95,7 +99,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Reshape(desc.clone())), - ReshapeDimsOps::::new(desc), + ReshapeDimsOps::::new(desc), ); out @@ -106,11 +110,14 @@ impl IntTensorOps for Fusion { ranges: [Range; D2], ) -> IntTensor { #[derive(new)] - struct SliceOps { + struct SliceOps { desc: SliceOperationDescription, + _b: PhantomData, } - impl Operation for SliceOps { + impl Operation + for SliceOps + { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); @@ -138,7 +145,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Slice(desc.clone())), - SliceOps::::new(desc), + SliceOps::::new(desc), ); out @@ -150,11 +157,14 @@ impl IntTensorOps for Fusion { value: IntTensor, ) -> IntTensor { #[derive(new)] - struct SliceAssignOps { + struct 
SliceAssignOps { desc: SliceAssignOperationDescription, + _b: PhantomData, } - impl Operation for SliceAssignOps { + impl Operation + for SliceAssignOps + { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let value = handles.get_int_tensor::(&self.desc.value); @@ -182,7 +192,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseInt(BaseOperationDescription::SliceAssign(desc.clone())), - SliceAssignOps::::new(desc), + SliceAssignOps::::new(desc), ); out @@ -194,11 +204,12 @@ impl IntTensorOps for Fusion { value: IntTensor, ) -> IntTensor { #[derive(new)] - struct MaskWhereOps { + struct MaskWhereOps { desc: MaskWhereOperationDescription, + _b: PhantomData, } - impl Operation for MaskWhereOps { + impl Operation for MaskWhereOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let value = handles.get_int_tensor::(&self.desc.value); @@ -225,7 +236,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2, stream_3], OperationDescription::NumericInt(NumericOperationDescription::MaskWhere(desc.clone())), - MaskWhereOps::::new(desc), + MaskWhereOps::::new(desc), ); out @@ -237,11 +248,12 @@ impl IntTensorOps for Fusion { value: IntElem, ) -> IntTensor { #[derive(new)] - struct MaskFillOps { + struct MaskFillOps { desc: MaskFillOperationDescription, + _b: PhantomData, } - impl Operation for MaskFillOps { + impl Operation for MaskFillOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let mask = handles.get_bool_tensor::(&self.desc.mask); @@ -265,7 +277,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::MaskFill(desc.clone())), - MaskFillOps::::new(desc), + MaskFillOps::::new(desc), ); out @@ -277,11 +289,12 @@ impl IntTensorOps for Fusion { indices: IntTensor, ) -> IntTensor { #[derive(new)] - struct GatherOps { + struct GatherOps { desc: GatherOperationDescription, + _b: PhantomData, } - impl Operation for GatherOps { + impl Operation for GatherOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); @@ -304,7 +317,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::Gather(desc.clone())), - GatherOps::::new(desc), + GatherOps::::new(desc), ); out @@ -317,11 +330,12 @@ impl IntTensorOps for Fusion { value: IntTensor, ) -> IntTensor { #[derive(new)] - struct ScatterOps { + struct ScatterOps { desc: ScatterOperationDescription, + _b: PhantomData, } - impl Operation for ScatterOps { + impl Operation for ScatterOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); @@ -348,7 +362,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2, stream_3], OperationDescription::NumericInt(NumericOperationDescription::Scatter(desc.clone())), - ScatterOps::::new(desc), + ScatterOps::::new(desc), ); out @@ -360,11 +374,12 @@ impl IntTensorOps for Fusion { indices: IntTensor, ) -> IntTensor { #[derive(new)] - struct SelectOps { + struct SelectOps { desc: SelectOperationDescription, + _b: 
PhantomData, } - impl Operation for SelectOps { + impl Operation for SelectOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); @@ -389,7 +404,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::Select(desc.clone())), - SelectOps::::new(desc), + SelectOps::::new(desc), ); out @@ -402,11 +417,12 @@ impl IntTensorOps for Fusion { value: IntTensor, ) -> IntTensor { #[derive(new)] - struct SelectAssignOps { + struct SelectAssignOps { desc: SelectAssignOperationDescription, + _b: PhantomData, } - impl Operation for SelectAssignOps { + impl Operation for SelectAssignOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); @@ -435,7 +451,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::SelectAssign( desc.clone(), )), - SelectAssignOps::::new(desc), + SelectAssignOps::::new(desc), ); out @@ -443,11 +459,12 @@ impl IntTensorOps for Fusion { fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { #[derive(new)] - struct CatOps { + struct CatOps { desc: CatOperationDescription, + _b: PhantomData, } - impl Operation for CatOps { + impl Operation for CatOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc @@ -483,7 +500,7 @@ impl IntTensorOps for Fusion { client.register( streams, OperationDescription::BaseInt(BaseOperationDescription::Cat(desc.clone())), - CatOps::::new(desc), + CatOps::::new(desc), ); out @@ -509,7 +526,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseInt(BaseOperationDescription::Equal(desc.clone())), - EqualOps::::new(desc), + EqualOps::::new(desc), ); out @@ -532,7 +549,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::EqualElem(desc.clone())), - EqualElemOps::::new(desc), + EqualElemOps::::new(desc), ); out @@ -558,7 +575,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::Greater(desc.clone())), - GreaterOps::::new(desc), + GreaterOps::::new(desc), ); out @@ -583,7 +600,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::GreaterElem( desc.clone(), )), - GreaterElemOps::::new(desc), + GreaterElemOps::::new(desc), ); out @@ -611,7 +628,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::GreaterEqual( desc.clone(), )), - GreaterEqualOps::::new(desc), + GreaterEqualOps::::new(desc), ); out @@ -636,7 +653,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::GreaterEqualElem( desc.clone(), )), - GreaterEqualElemOps::::new(desc), + GreaterEqualElemOps::::new(desc), ); out @@ -662,7 +679,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::Lower(desc.clone())), - LowerOps::::new(desc), + LowerOps::::new(desc), ); out @@ -685,7 +702,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::LowerElem(desc.clone())), - LowerElemOps::::new(desc), + 
LowerElemOps::::new(desc), ); out @@ -711,7 +728,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::LowerEqual(desc.clone())), - LowerEqualOps::::new(desc), + LowerEqualOps::::new(desc), ); out @@ -736,7 +753,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::LowerEqualElem( desc.clone(), )), - LowerEqualElemOps::::new(desc), + LowerEqualElemOps::::new(desc), ); out @@ -762,7 +779,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], repr::OperationDescription::NumericInt(NumericOperationDescription::Add(desc.clone())), - AddOps::::new(desc), + AddOps::::new(desc), ); out @@ -787,7 +804,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::AddScalar( desc.clone(), )), - AddOps::::new(desc), + AddOps::::new(desc), ); out @@ -813,7 +830,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], repr::OperationDescription::NumericInt(NumericOperationDescription::Sub(desc.clone())), - SubOps::::new(desc), + SubOps::::new(desc), ); out @@ -838,7 +855,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::SubScalar( desc.clone(), )), - SubOps::::new(desc), + SubOps::::new(desc), ); out @@ -864,7 +881,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], repr::OperationDescription::NumericInt(NumericOperationDescription::Mul(desc.clone())), - MulOps::::new(desc), + MulOps::::new(desc), ); out @@ -889,7 +906,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::MulScalar( desc.clone(), )), - MulOps::::new(desc), + MulOps::::new(desc), ); out @@ -915,7 +932,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], repr::OperationDescription::NumericInt(NumericOperationDescription::Div(desc.clone())), - DivOps::::new(desc), + DivOps::::new(desc), ); out @@ -940,7 +957,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::DivScalar( desc.clone(), )), - DivOps::::new(desc), + DivOps::::new(desc), ); out @@ -965,7 +982,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::RemScalar( desc.clone(), )), - ModOps::::new(desc), + ModOps::::new(desc), ); out @@ -978,7 +995,7 @@ impl IntTensorOps for Fusion { device: Device, } - impl Operation for ZerosOps { + impl Operation for ZerosOps { fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.shape.clone()); let output = B::int_zeros::(shape, &self.device); @@ -1007,7 +1024,7 @@ impl IntTensorOps for Fusion { device: Device, } - impl Operation for OnesOps { + impl Operation for OnesOps { fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.shape.clone()); let output = B::int_ones::(shape, &self.device); @@ -1043,7 +1060,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Sum(desc.clone())), - SumOps::::new(desc), + SumOps::::new(desc), ); out @@ -1065,7 +1082,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::SumDim(desc.clone())), - SumDimOps::::new(desc), + SumDimOps::::new(desc), ); out @@ -1084,7 +1101,7 @@ impl IntTensorOps for Fusion { 
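// Sketch of the call-site half of this refactor: because the handle container
// is now generic over an untyped handle type rather than a backend, every
// getter and registration names the backend (and rank) explicitly via
// turbofish. The signatures below are assumed simplifications, not the exact
// burn-tensor API.

use std::collections::HashMap;

pub struct TensorDescription; // stand-in
pub struct TensorId; // stand-in

pub struct HandleContainer<H> {
    handles: HashMap<u64, H>,
}

impl<H> HandleContainer<H> {
    // Assumed simplification of get_int_tensor::<B, D>: fetch the stored
    // handle and (in the real code) turn it into B's rank-D tensor primitive.
    pub fn get_int_tensor<B, const D: usize>(&mut self, _desc: &TensorDescription) -> Option<&H> {
        self.handles.get(&0)
    }

    // Assumed simplification of register_int_tensor::<B, D>: wrap B's output
    // tensor back into an untyped handle stored under the output id.
    pub fn register_int_tensor<B, const D: usize>(&mut self, _id: &TensorId, handle: H) {
        self.handles.insert(0, handle);
    }
}

// Since B and D no longer appear in the container's type, they cannot be
// inferred; call sites must spell them out, which is exactly why the hunks
// rewrite e.g. AddOps::new(desc) into AddOps::<B>::new(desc) and
// handles.get_int_tensor(..) into handles.get_int_tensor::<B, D>(..).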
out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Prod(desc.clone())), - ProdOps::::new(desc), + ProdOps::::new(desc), ); out @@ -1106,7 +1123,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::ProdDim(desc.clone())), - ProdDimOps::::new(desc), + ProdDimOps::::new(desc), ); out @@ -1125,7 +1142,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Mean(desc.clone())), - MeanOps::::new(desc), + MeanOps::::new(desc), ); out @@ -1147,7 +1164,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::MeanDim(desc.clone())), - MeanDimOps::::new(desc), + MeanDimOps::::new(desc), ); out @@ -1169,7 +1186,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::ArgMax(desc.clone())), - ArgMaxOps::::new(desc), + ArgMaxOps::::new(desc), ); out @@ -1191,7 +1208,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::ArgMin(desc.clone())), - ArgMinOps::::new(desc), + ArgMinOps::::new(desc), ); out @@ -1203,11 +1220,12 @@ impl IntTensorOps for Fusion { max: IntElem, ) -> IntTensor { #[derive(new)] - struct ClampOps { + struct ClampOps { desc: ClampOperationDescription, + _b: PhantomData, } - impl Operation for ClampOps { + impl Operation for ClampOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem()); @@ -1227,7 +1245,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Clamp(desc.clone())), - ClampOps::::new(desc), + ClampOps::::new(desc), ); out @@ -1246,7 +1264,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Abs(desc.clone())), - AbsOps::::new(desc), + AbsOps::::new(desc), ); out @@ -1254,11 +1272,12 @@ impl IntTensorOps for Fusion { fn int_into_float(tensor: IntTensor) -> FloatTensor { #[derive(new)] - struct IntoFloatOps { + struct IntoFloatOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for IntoFloatOps { + impl Operation for IntoFloatOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_into_float(input); @@ -1275,7 +1294,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Int(repr::IntOperationDescription::IntoFloat(desc.clone())), - IntoFloatOps::::new(desc), + IntoFloatOps::::new(desc), ); out @@ -1287,11 +1306,12 @@ impl IntTensorOps for Fusion { dim2: usize, ) -> IntTensor { #[derive(new)] - struct SwapDimsOps { + struct SwapDimsOps { desc: SwapDimsDescription, + _b: PhantomData, } - impl Operation for SwapDimsOps { + impl Operation for SwapDimsOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2); @@ -1315,7 +1335,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::SwapDims(desc.clone())), - 
SwapDimsOps::::new(desc), + SwapDimsOps::::new(desc), ); out @@ -1334,7 +1354,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Max(desc.clone())), - MaxOps::::new(desc), + MaxOps::::new(desc), ); out @@ -1356,7 +1376,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::MaxDim(desc.clone())), - MaxDimOps::::new(desc), + MaxDimOps::::new(desc), ); out @@ -1367,11 +1387,12 @@ impl IntTensorOps for Fusion { dim: usize, ) -> (IntTensor, IntTensor) { #[derive(new)] - struct MaxDimWithIndicesOps { + struct MaxDimWithIndicesOps { desc: ReduceDimWithIndicesDescription, + _b: PhantomData, } - impl Operation for MaxDimWithIndicesOps { + impl Operation for MaxDimWithIndicesOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim); @@ -1398,7 +1419,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::MaxDimWithIndices( desc.clone(), )), - MaxDimWithIndicesOps::::new(desc), + MaxDimWithIndicesOps::::new(desc), ); (out, out_indices) @@ -1417,7 +1438,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Min(desc.clone())), - MinOps::::new(desc), + MinOps::::new(desc), ); out @@ -1439,7 +1460,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::MinDim(desc.clone())), - MinDimOps::::new(desc), + MinDimOps::::new(desc), ); out @@ -1450,11 +1471,12 @@ impl IntTensorOps for Fusion { dim: usize, ) -> (IntTensor, IntTensor) { #[derive(new)] - struct MinDimWithIndicesOps { + struct MinDimWithIndicesOps { desc: ReduceDimWithIndicesDescription, + _b: PhantomData, } - impl Operation for MinDimWithIndicesOps { + impl Operation for MinDimWithIndicesOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim); @@ -1481,7 +1503,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::MinDimWithIndices( desc.clone(), )), - MinDimWithIndicesOps::::new(desc), + MinDimWithIndicesOps::::new(desc), ); (out, out_indices) @@ -1498,7 +1520,7 @@ impl IntTensorOps for Fusion { device: Device, } - impl Operation for IntRandomOps { + impl Operation for IntRandomOps { fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.out.shape.clone()); let output: B::IntTensorPrimitive = @@ -1530,11 +1552,12 @@ impl IntTensorOps for Fusion { axes: [usize; D], ) -> IntTensor { #[derive(new)] - struct PermuteDimsOps { + struct PermuteDimsOps { desc: PermuteOperationDescription, + _b: PhantomData, } - impl Operation for PermuteDimsOps { + impl Operation for PermuteDimsOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); @@ -1559,7 +1582,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Permute(desc.clone())), - PermuteDimsOps::::new(desc), + PermuteDimsOps::::new(desc), ); out @@ -1570,11 +1593,14 @@ impl IntTensorOps for Fusion { shape: Shape, ) -> 
IntTensor { #[derive(new)] - struct ExpandOps { + struct ExpandOps { desc: ExpandOperationDescription, + _b: PhantomData, } - impl Operation for ExpandOps { + impl Operation + for ExpandOps + { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); @@ -1596,7 +1622,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Expand(desc.clone())), - ExpandOps::::new(desc), + ExpandOps::::new(desc), ); out @@ -1604,11 +1630,12 @@ impl IntTensorOps for Fusion { fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { #[derive(new)] - struct FlipDimsOps { + struct FlipDimsOps { desc: FlipOperationDescription, + _b: PhantomData, } - impl Operation for FlipDimsOps { + impl Operation for FlipDimsOps { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let axes = &self.desc.axes; @@ -1630,7 +1657,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Flip(desc.clone())), - FlipDimsOps::::new(desc), + FlipDimsOps::::new(desc), ); out @@ -1642,11 +1669,12 @@ impl IntTensorOps for Fusion { times: usize, ) -> IntTensor { #[derive(new)] - struct RepeatOps { + struct RepeatOps { desc: RepeatOperationDescription, + _b: PhantomData, } - impl Operation for RepeatOps { + impl Operation for RepeatOps { fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); @@ -1670,7 +1698,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())), - RepeatOps::::new(desc), + RepeatOps::::new(desc), ); out diff --git a/crates/burn-fusion/src/ops/module.rs b/crates/burn-fusion/src/ops/module.rs index a3150c97a7..7bbbfbdd1b 100644 --- a/crates/burn-fusion/src/ops/module.rs +++ b/crates/burn-fusion/src/ops/module.rs @@ -11,15 +11,17 @@ use burn_tensor::{ }, repr::*, }; +use std::marker::PhantomData; macro_rules! 
make_ops { ($name:ident, $desc:ty, $fn:expr) => { #[derive(new)] - struct $name { + struct $name { desc: $desc, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { #[allow(clippy::redundant_closure_call)] $fn(self.desc, handles) @@ -79,7 +81,7 @@ impl ModuleOps> for Fusion { out.client.clone().register( streams, OperationDescription::Module(ModuleOperationDescription::Conv1d(description.clone())), - Conv1dOps::new(description), + Conv1dOps::::new(description), ); out @@ -144,7 +146,7 @@ impl ModuleOps> for Fusion { out.client.register( streams, OperationDescription::Module(ModuleOperationDescription::Conv2d(desc.clone())), - Conv2dOps::new(desc), + Conv2dOps::::new(desc), ); out @@ -203,7 +205,7 @@ impl ModuleOps> for Fusion { out.client.register( streams, OperationDescription::Module(ModuleOperationDescription::ConvTranspose1d(desc.clone())), - ConvTranspose1dOps::new(desc), + ConvTranspose1dOps::::new(desc), ); out @@ -270,7 +272,7 @@ impl ModuleOps> for Fusion { out.client.register( streams, OperationDescription::Module(ModuleOperationDescription::ConvTranspose2d(desc.clone())), - ConvTranspose2dOps::new(desc), + ConvTranspose2dOps::::new(desc), ); out @@ -316,7 +318,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::AvgPool1d(desc.clone())), - AvgPool1dOps::new(desc), + AvgPool1dOps::::new(desc), ); out @@ -366,7 +368,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::AvgPool2d(desc.clone())), - AvgPool2dOps::new(desc), + AvgPool2dOps::::new(desc), ); out @@ -417,7 +419,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AvgPool1dBackward( desc.clone(), )), - AvgPool1dBackwardOps::new(desc), + AvgPool1dBackwardOps::::new(desc), ); out @@ -468,7 +470,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AvgPool2dBackward( desc.clone(), )), - AvgPool2dBackwardOps::new(desc), + AvgPool2dBackwardOps::::new(desc), ); out @@ -515,7 +517,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::MaxPool1d(desc.clone())), - MaxPool1dOps::new(desc), + MaxPool1dOps::::new(desc), ); out @@ -575,7 +577,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::MaxPool2d(desc.clone())), - MaxPool2dOps::new(desc), + MaxPool2dOps::::new(desc), ); out @@ -626,7 +628,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndices( desc.clone(), )), - MaxPool1dWithIndicesOps::new(desc), + MaxPool1dWithIndicesOps::::new(desc), ); MaxPool1dWithIndices::new(out, out_indices) @@ -691,7 +693,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndices( desc.clone(), )), - MaxPool2dWithIndicesOps::new(desc), + MaxPool2dWithIndicesOps::::new(desc), ); MaxPool2dWithIndices::new(out, out_indices) @@ -748,7 +750,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndicesBackward( desc.clone(), )), - MaxPool1dWithIndicesBackwardOps::new(desc), + MaxPool1dWithIndicesBackwardOps::::new(desc), ); MaxPool1dBackward::new(out) @@ -805,7 +807,7 @@ impl ModuleOps> for Fusion { 
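// Assumed simplification of module.rs's make_ops! after this patch: the
// macro now emits a backend-generic struct carrying a PhantomData marker.
// The derive(new) of the real code is replaced by an explicit constructor
// here, and the $fn body (which receives the description and the handle
// container) is elided so the sketch stands alone.

use std::marker::PhantomData;

pub struct Conv1dDescription; // stand-in

macro_rules! make_ops {
    ($name:ident, $desc:ty, $fn:expr) => {
        struct $name<B> {
            desc: $desc,
            _b: PhantomData<B>,
        }

        impl<B> $name<B> {
            fn new(desc: $desc) -> Self {
                Self {
                    desc,
                    _b: PhantomData,
                }
            }
        }
    };
}

make_ops!(Conv1dOps, Conv1dDescription, |_desc, _handles| { /* run B::conv1d */ });

fn main() {
    // Call sites now name the backend explicitly, matching the diff's change
    // from Conv1dOps::new(description) to Conv1dOps::<B>::new(description).
    let _op = Conv1dOps::<()>::new(Conv1dDescription);
}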
OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndicesBackward( desc.clone(), )), - MaxPool2dWithIndicesBackwardOps::new(desc), + MaxPool2dWithIndicesBackwardOps::::new(desc), ); MaxPool2dBackward::new(out) @@ -837,7 +839,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1d( desc.clone(), )), - AdaptiveAvgPool1dOps::new(desc), + AdaptiveAvgPool1dOps::::new(desc), ); out @@ -872,7 +874,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2d( desc.clone(), )), - AdaptiveAvgPool2dOps::new(desc), + AdaptiveAvgPool2dOps::::new(desc), ); out @@ -909,7 +911,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1dBackward( desc.clone(), )), - AdaptiveAvgPool1dBackwardOps::new(desc), + AdaptiveAvgPool1dBackwardOps::::new(desc), ); out @@ -946,7 +948,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2dBackward( desc.clone(), )), - AdaptiveAvgPool2dBackwardOps::new(desc), + AdaptiveAvgPool2dBackwardOps::::new(desc), ); out @@ -981,7 +983,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::Interpolate(desc.clone())), - InterpolateOps::new(desc), + InterpolateOps::::new(desc), ); out @@ -1022,7 +1024,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::InterpolateBackward( desc.clone(), )), - InterpolateBackwardOps::new(desc), + InterpolateBackwardOps::::new(desc), ); out } diff --git a/crates/burn-fusion/src/ops/unary.rs b/crates/burn-fusion/src/ops/unary.rs index 921f2503da..0120b77948 100644 --- a/crates/burn-fusion/src/ops/unary.rs +++ b/crates/burn-fusion/src/ops/unary.rs @@ -13,11 +13,12 @@ macro_rules! scalar_float_ops { $elem:ty ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); @@ -33,11 +34,12 @@ macro_rules! scalar_float_ops { noconvert ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); @@ -57,11 +59,12 @@ macro_rules! scalar_float2int_ops { $elem:ty ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs.clone()); @@ -80,11 +83,12 @@ macro_rules! unary_float_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = $ops(input); @@ -99,11 +103,12 @@ macro_rules! 
unary_float_ops { reduce ) => { #[derive(new)] - struct $name { + struct $name { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = $ops(input); @@ -122,11 +127,12 @@ macro_rules! unary_int_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = $ops(input); @@ -141,11 +147,12 @@ macro_rules! unary_int_ops { reduce ) => { #[derive(new)] - struct $name { + struct $name { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = $ops(input); @@ -164,11 +171,12 @@ macro_rules! scalar_float_cmp_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); @@ -187,11 +195,12 @@ macro_rules! scalar_int_cmp_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); @@ -217,11 +226,12 @@ macro_rules! scalar_int_ops { $elem:ty ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); @@ -237,11 +247,12 @@ macro_rules! 
scalar_int_ops { noconvert ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { + impl Operation for $name { fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index 3e6b5d911f..1d26532f5c 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -1,6 +1,6 @@ use crate::{ stream::{execution::Operation, MultiStream, StreamId}, - FusionBackend, + FusionBackend, FusionRuntime, }; use burn_tensor::{ ops::{FloatElem, IntElem}, @@ -8,20 +8,17 @@ use burn_tensor::{ }; use std::sync::Arc; -pub struct FusionServer -where - B: FusionBackend, -{ - streams: MultiStream, - pub(crate) handles: HandleContainer, - pub device: B::Device, +pub struct FusionServer { + streams: MultiStream, + pub(crate) handles: HandleContainer, + pub device: R::FusionDevice, } -impl FusionServer +impl FusionServer where - B: FusionBackend, + R: FusionRuntime, { - pub fn new(device: B::Device) -> Self { + pub fn new(device: R::FusionDevice) -> Self { Self { streams: MultiStream::new(device.clone()), handles: HandleContainer::new(), @@ -33,7 +30,7 @@ where &mut self, streams: Vec, desc: OperationDescription, - operation: Box>, + operation: Box>, ) { self.streams .register(streams, desc, operation, &mut self.handles) @@ -47,11 +44,14 @@ where self.handles.create_tensor_uninit() } - pub fn read_float( + pub fn read_float( &mut self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader, D>> { + ) -> burn_tensor::Reader, D>> + where + B: FusionBackend, + { // Make sure all registered operations are executed. // The underlying backend can still be async. self.drain_stream(id); @@ -60,11 +60,14 @@ where B::float_into_data(tensor) } - pub fn read_int( + pub fn read_int( &mut self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader, D>> { + ) -> burn_tensor::Reader, D>> + where + B: FusionBackend, + { // Make sure all registered operations are executed. // The underlying backend can still be async. self.drain_stream(id); @@ -73,11 +76,14 @@ where B::int_into_data(tensor) } - pub fn read_bool( + pub fn read_bool( &mut self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader> { + ) -> burn_tensor::Reader> + where + B: FusionBackend, + { // Make sure all registered operations are executed. // The underlying backend can still be async. 
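// Sketch of the server.rs refactor visible in the next hunks: the server is
// generic over the runtime alone, and only the methods that must materialize
// data (read_float, read_int, read_bool, change_server_*) add a
// `B: FusionBackend` bound tied back to that runtime. Trait shapes below are
// stand-ins for illustration.

use std::marker::PhantomData;

pub struct HandleContainer<H>(PhantomData<H>);
impl<H> HandleContainer<H> {
    pub fn new() -> Self {
        HandleContainer(PhantomData)
    }
}

pub trait FusionRuntime {
    type FusionHandle;
    type FusionDevice: Clone;
}

pub trait FusionBackend {
    type FusionRuntime: FusionRuntime;
}

pub struct FusionServer<R: FusionRuntime> {
    pub handles: HandleContainer<R::FusionHandle>,
    pub device: R::FusionDevice,
}

impl<R: FusionRuntime> FusionServer<R> {
    pub fn new(device: R::FusionDevice) -> Self {
        Self {
            handles: HandleContainer::new(),
            device,
        }
    }

    // Backend-specific entry point: the equality bound ties B to this server's
    // runtime, so handles registered by any such backend stay compatible.
    pub fn read_float<B>(&mut self)
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        // Drain pending operations for the stream first, then read the data
        // through the concrete backend, as the real read_float does with
        // B::float_into_data.
    }
}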
self.drain_stream(id); @@ -86,12 +92,15 @@ where B::bool_into_data(tensor) } - pub fn change_server_float( + pub fn change_server_float( &mut self, tensor: &TensorDescription, - device: &B::Device, + device: &R::FusionDevice, server_device: &mut Self, - ) -> Arc { + ) -> Arc + where + B: FusionBackend, + { let tensor = self.handles.get_float_tensor::(tensor); let tensor = B::float_to_device(tensor, device); let id = server_device.create_empty_handle(); @@ -103,12 +112,15 @@ where id } - pub fn change_server_int( + pub fn change_server_int( &mut self, tensor: &TensorDescription, - device: &B::Device, + device: &R::FusionDevice, server_device: &mut Self, - ) -> Arc { + ) -> Arc + where + B: FusionBackend, + { let tensor = self.handles.get_int_tensor::(tensor); let tensor = B::int_to_device(tensor, device); let id = server_device.create_empty_handle(); @@ -120,12 +132,15 @@ where id } - pub fn change_server_bool( + pub fn change_server_bool( &mut self, tensor: &TensorDescription, - device: &B::Device, + device: &R::FusionDevice, server_device: &mut Self, - ) -> Arc { + ) -> Arc + where + B: FusionBackend, + { let tensor = self.handles.get_bool_tensor::(tensor); let tensor = B::bool_to_device(tensor, device); let id = server_device.create_empty_handle(); diff --git a/crates/burn-fusion/src/stream/base.rs b/crates/burn-fusion/src/stream/base.rs index fb3d8f99b3..31ebfb6146 100644 --- a/crates/burn-fusion/src/stream/base.rs +++ b/crates/burn-fusion/src/stream/base.rs @@ -1,18 +1,16 @@ -use burn_tensor::repr::OperationDescription; - -use crate::FusionBackend; - use super::{execution::Operation, OperationConverter, RelativeOps}; +use crate::FusionRuntime; +use burn_tensor::repr::OperationDescription; /// A growing list of [tensor operation descriptions](OperationDescription). -pub struct OperationQueue { +pub struct OperationQueue { pub(crate) global: Vec, pub(crate) relative: Vec, pub(crate) converter: OperationConverter, - pub(crate) operations: Vec>>, + pub(crate) operations: Vec>>, } -impl Default for OperationQueue { +impl Default for OperationQueue { fn default() -> Self { Self::new() } @@ -56,7 +54,7 @@ impl core::fmt::Display for StreamId { } } -impl OperationQueue { +impl OperationQueue { /// Create a new empty queue. pub fn new() -> Self { Self { @@ -72,7 +70,7 @@ impl OperationQueue { /// The new [operation description](OperationDescription) will be converted to a local /// representation that can be reused when the same pattern emerge in different but similar /// scenario, so that the same optimization can be used. - pub fn add(&mut self, global: OperationDescription, operation: Box>) { + pub fn add(&mut self, global: OperationDescription, operation: Box>) { let relative = global.to_relative(&mut self.converter); self.relative.push(relative); self.global.push(global); diff --git a/crates/burn-fusion/src/stream/execution/base.rs b/crates/burn-fusion/src/stream/execution/base.rs index 91f6b322cc..d24733657b 100644 --- a/crates/burn-fusion/src/stream/execution/base.rs +++ b/crates/burn-fusion/src/stream/execution/base.rs @@ -5,7 +5,7 @@ use crate::{ store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy}, OperationQueue, RelativeOps, }, - FusionBackend, Optimization, + FusionRuntime, Optimization, }; /// The mode in which the execution is done. @@ -16,18 +16,18 @@ pub(crate) enum ExecutionMode { } /// General trait to abstract how a single operation is executed. -pub trait Operation: Send + Sync { +pub trait Operation: Send + Sync { /// Execute the operation. 
- fn execute(self: Box, handles: &mut HandleContainer); + fn execute(self: Box, handles: &mut HandleContainer); } -impl OperationQueue { +impl OperationQueue { /// Execute the queue partially following the execution strategy from the plan. pub(crate) fn execute( &mut self, id: ExecutionPlanId, - handles: &mut HandleContainer, - store: &mut ExecutionPlanStore, + handles: &mut HandleContainer, + store: &mut ExecutionPlanStore, ) { match &mut store.get_mut_unchecked(id).strategy { ExecutionStrategy::Optimization(optimization) => { @@ -39,8 +39,8 @@ impl OperationQueue { fn execute_optimization( &mut self, - handles: &mut HandleContainer, - optimization: &mut B::Optimization, + handles: &mut HandleContainer, + optimization: &mut R::Optimization, ) { let num_drained = optimization.len(); @@ -51,7 +51,7 @@ impl OperationQueue { self.operations.drain(0..num_drained); } - fn execute_operations(&mut self, handles: &mut HandleContainer) { + fn execute_operations(&mut self, handles: &mut HandleContainer) { let num_drained = self.operations.len(); for operation in self.operations.drain(0..num_drained) { @@ -61,7 +61,7 @@ impl OperationQueue { self.drain_stream(num_drained, handles); } - fn drain_stream(&mut self, num_drained: usize, handles: &mut HandleContainer) { + fn drain_stream(&mut self, num_drained: usize, handles: &mut HandleContainer) { self.global[0..num_drained] .iter() .flat_map(|desc| desc.nodes()) diff --git a/crates/burn-fusion/src/stream/multi.rs b/crates/burn-fusion/src/stream/multi.rs index c12442775f..4edcbfe1ce 100644 --- a/crates/burn-fusion/src/stream/multi.rs +++ b/crates/burn-fusion/src/stream/multi.rs @@ -5,18 +5,18 @@ use super::{ store::{ExecutionPlanId, ExecutionPlanStore}, OperationQueue, StreamId, }; -use crate::FusionBackend; +use crate::FusionRuntime; use std::collections::HashMap; /// Keep track of multiple concurrent streams of operations. -pub struct MultiStream { - streams: HashMap>, - optimizations: ExecutionPlanStore, - device: B::Device, +pub struct MultiStream { + streams: HashMap>, + optimizations: ExecutionPlanStore, + device: R::FusionDevice, } -impl MultiStream { - pub(crate) fn new(device: B::Device) -> Self { +impl MultiStream { + pub(crate) fn new(device: R::FusionDevice) -> Self { Self { streams: HashMap::new(), optimizations: ExecutionPlanStore::new(), @@ -29,8 +29,8 @@ impl MultiStream { &mut self, streams: Vec, desc: OperationDescription, - operation: Box>, - handles: &mut HandleContainer, + operation: Box>, + handles: &mut HandleContainer, ) { let id = self.maybe_drain(streams, handles); @@ -65,7 +65,7 @@ impl MultiStream { } /// Drain the streams. 
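// Sketch of the stream-side change in these hunks: the operation queue (and,
// below, MultiStream and Stream) is keyed by the runtime R rather than a
// backend, so the stored boxed operations are trait objects over
// Operation<R>. Stand-in types as in the earlier sketches.

use std::marker::PhantomData;

pub struct HandleContainer<H>(PhantomData<H>);

pub trait FusionRuntime {
    type FusionHandle;
}

pub trait Operation<R: FusionRuntime>: Send + Sync {
    fn execute(self: Box<Self>, handles: &mut HandleContainer<R::FusionHandle>);
}

pub struct OperationQueue<R: FusionRuntime> {
    operations: Vec<Box<dyn Operation<R>>>,
}

impl<R: FusionRuntime> OperationQueue<R> {
    // Drain every queued operation against the shared handle container,
    // mirroring execute_operations in execution/base.rs.
    pub fn execute_all(&mut self, handles: &mut HandleContainer<R::FusionHandle>) {
        for operation in self.operations.drain(..) {
            operation.execute(handles);
        }
    }
}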
- pub fn drain(&mut self, handles: &mut HandleContainer, id: StreamId) { + pub fn drain(&mut self, handles: &mut HandleContainer, id: StreamId) { if let Some(mut stream) = self.streams.remove(&id) { stream.processor.process( Segment::new(&mut stream.queue, handles), @@ -80,7 +80,7 @@ impl MultiStream { fn maybe_drain( &mut self, streams: Vec, - handles: &mut HandleContainer, + handles: &mut HandleContainer, ) -> StreamId { let streams = Self::remove_duplicate(streams); let current = StreamId::current(); @@ -113,7 +113,7 @@ impl MultiStream { output } - fn free_orphans(&self, handles: &mut HandleContainer) { + fn free_orphans(&self, handles: &mut HandleContainer) { let nodes = self .streams .values() @@ -126,31 +126,31 @@ impl MultiStream { } } -struct Stream { - queue: OperationQueue, - processor: Processor, +struct Stream { + queue: OperationQueue, + processor: Processor, } #[derive(new)] -struct Segment<'a, B: FusionBackend> { - queue: &'a mut OperationQueue, - handles: &'a mut HandleContainer, +struct Segment<'a, R: FusionRuntime> { + queue: &'a mut OperationQueue, + handles: &'a mut HandleContainer, } -impl<'i, B: FusionBackend> StreamSegment for Segment<'i, B> { +impl<'i, R: FusionRuntime> StreamSegment for Segment<'i, R> { fn operations(&self) -> &[OperationDescription] { &self.queue.relative } - fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore) { + fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore) { self.queue.execute(id, self.handles, store) } } -impl Stream { - fn new(device: B::Device) -> Self { +impl Stream { + fn new(device: R::FusionDevice) -> Self { Self { - processor: Processor::new(B::optimizations(device)), + processor: Processor::new(R::optimizations(device)), queue: OperationQueue::new(), } } diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index f97f92a532..95ddf6fafe 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -4,7 +4,7 @@ use crate::{ JitBackend, Runtime, }; use burn_compute::client::ComputeClient; -use burn_fusion::{client::MutexFusionClient, FusionBackend}; +use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; use serde::{Deserialize, Serialize}; @@ -26,11 +26,9 @@ pub enum JitOptimizationState { ElementWise(ElementWiseState), } -impl burn_fusion::Optimization> for JitOptimization +impl burn_fusion::Optimization> for JitOptimization where R: Runtime, - F: FloatElement, - I: IntElement, { fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle>) { match self { @@ -102,18 +100,28 @@ impl ReprBackend for JitBackend FusionBackend for JitBackend { +impl FusionRuntime for FusionJitRuntime { type OptimizationState = JitOptimizationState; type Optimization = JitOptimization; - type FusionClient = MutexFusionClient; + type FusionHandle = JitFusionHandle; + type FusionDevice = R::Device; fn optimizations( device: R::Device, ) -> Vec>> { - vec![Box::new(ElementWiseBuilder::::new(device))] + vec![Box::new(ElementWiseBuilder::::new(device))] } } +pub struct FusionJitRuntime { + _b: PhantomData, +} + +impl FusionBackend for JitBackend { + type FusionClient = MutexFusionClient; + type FusionRuntime = FusionJitRuntime; +} + pub fn strides_dyn_rank(shape: &[usize]) -> Vec { let mut strides = vec![0; shape.len()]; From 8cb45efc738b306800103c17eb62b68a374dadb3 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 30 Apr 2024 
15:34:15 -0400 Subject: [PATCH 05/25] Fix fusion --- crates/burn-fusion/src/client/base.rs | 5 +- crates/burn-fusion/src/client/mutex.rs | 14 +- crates/burn-fusion/src/ops/boolean.rs | 39 ++- crates/burn-fusion/src/ops/float.rs | 265 ++++++++++++------ crates/burn-fusion/src/ops/int.rs | 226 ++++++++++----- crates/burn-fusion/src/ops/module.rs | 63 +++-- crates/burn-fusion/src/stream/context.rs | 14 +- .../src/stream/execution/policy.rs | 10 +- .../burn-fusion/src/stream/execution/tests.rs | 18 +- crates/burn-fusion/src/stream/store/index.rs | 17 +- crates/burn-fusion/src/tensor.rs | 15 +- crates/burn-jit/src/fusion/base.rs | 2 +- crates/burn-tensor/src/repr/tensor.rs | 4 + crates/burn-tensor/src/tensor/element.rs | 59 +++- 14 files changed, 534 insertions(+), 217 deletions(-) diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index eca898b5c9..1587bcde48 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -6,7 +6,7 @@ use burn_tensor::{ backend::Backend, ops::{FloatElem, IntElem}, repr::{OperationDescription, TensorDescription, TensorId}, - Data, Device, Reader, + DType, Data, Device, Reader, }; /// Define how to interact with the fusion server. @@ -25,13 +25,14 @@ pub trait FusionClient: Send + Sync + Clone { /// Get the current device used by all operations handled by this client. fn device(&self) -> &::Device; /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. - fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor; + fn tensor_uninitialized(&self, shape: Vec, dtype: DType) -> FusionTensor; /// Create a tensor with the given handle and shape. fn register_tensor( &self, handle: Handle, shape: Vec, stream: StreamId, + dtype: DType, ) -> FusionTensor; /// Read the values contained by a float tensor. 
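// Sketch of the dtype threading this second patch introduces: the client's
// tensor constructors now take a DType, so every FusionTensor records its
// element type alongside its shape. The DType variants and struct layout
// here are assumptions for illustration, not the burn-tensor definitions.

#[derive(Clone, Copy, Debug)]
pub enum DType {
    F32,
    I64,
    Bool,
}

pub struct FusionTensor {
    pub shape: Vec<usize>,
    pub dtype: DType,
}

impl FusionTensor {
    // The real constructor also takes the tensor id, client, and stream id;
    // the point is the new dtype parameter.
    pub fn new(shape: Vec<usize>, dtype: DType) -> Self {
        Self { shape, dtype }
    }
}

pub trait FusionClient {
    // Uninitialized tensors are created with an explicit dtype: DType::Bool
    // for bool ops, or the backend element's dtype for float/int ops.
    fn tensor_uninitialized(&self, shape: Vec<usize>, dtype: DType) -> FusionTensor;
}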
fn read_tensor_float( diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 6e06963bd5..ae360fb73d 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -7,6 +7,7 @@ use burn_tensor::{ backend::Backend, ops::FloatElem, repr::{OperationDescription, TensorDescription, TensorId}, + DType, }; use spin::Mutex; use std::sync::Arc; @@ -59,10 +60,10 @@ where self.server.lock().drain_stream(id); } - fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor { + fn tensor_uninitialized(&self, shape: Vec, dtype: DType) -> FusionTensor { let id = self.server.lock().create_empty_handle(); - FusionTensor::new(id, shape, self.clone(), StreamId::current()) + FusionTensor::new(id, shape, dtype, self.clone(), StreamId::current()) } fn device(&self) -> &::Device { @@ -73,13 +74,14 @@ where handle: Handle, shape: Vec, stream: StreamId, + dtype: DType, ) -> FusionTensor { let mut server = self.server.lock(); let id = server.create_empty_handle(); server.handles.register_handle(*id.as_ref(), handle); core::mem::drop(server); - FusionTensor::new(id, shape, self.clone(), stream) + FusionTensor::new(id, shape, dtype, self.clone(), stream) } fn read_tensor_float( @@ -123,7 +125,7 @@ where core::mem::drop(server_other); core::mem::drop(server_current); - FusionTensor::new(id, tensor.shape, client, StreamId::current()) + FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } fn change_client_int( @@ -142,7 +144,7 @@ where core::mem::drop(server_other); core::mem::drop(server_current); - FusionTensor::new(id, tensor.shape, client, StreamId::current()) + FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } fn change_client_bool( @@ -161,7 +163,7 @@ where core::mem::drop(server_other); core::mem::drop(server_current); - FusionTensor::new(id, tensor.shape, client, StreamId::current()) + FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } fn register_orphan(&self, id: &TensorId) { diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index c156f2d3cb..6eba34acd6 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -1,3 +1,4 @@ +use burn_tensor::{DType, Element}; use std::marker::PhantomData; use crate::{ @@ -28,6 +29,7 @@ impl BoolTensorOps for Fusion { B::bool_tensor_handle(tensor), shape.dims.into(), StreamId::current(), + DType::Bool, ) } @@ -53,6 +55,7 @@ impl BoolTensorOps for Fusion { B::bool_tensor_handle(tensor), shape.dims.into(), StreamId::current(), + DType::Bool, ) } @@ -74,7 +77,9 @@ impl BoolTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -108,7 +113,9 @@ impl BoolTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -171,7 +178,7 @@ impl BoolTensorOps for Fusion { let stream = tensor.stream; let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = 
ReshapeDescription { input: tensor.into_description(), @@ -216,7 +223,7 @@ impl BoolTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = SliceOperationDescription { tensor: tensor.into_description(), @@ -263,7 +270,7 @@ impl BoolTensorOps for Fusion { let shape: Vec = tensor.shape.clone(); let stream_1 = tensor.stream; let stream_2 = value.stream; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = SliceAssignOperationDescription { tensor: tensor.into_description(), @@ -318,7 +325,7 @@ impl BoolTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, DType::Bool); let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(), @@ -357,7 +364,7 @@ impl BoolTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -389,7 +396,9 @@ impl BoolTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), DType::Bool); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -429,7 +438,7 @@ impl BoolTensorOps for Fusion { shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = SwapDimsDescription { input: tensor.into_description(), @@ -470,7 +479,7 @@ impl BoolTensorOps for Fusion { // Change the shape of the tensor to match the new axes let shape = axes.into_iter().map(|x| tensor.shape[x]).collect(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = PermuteOperationDescription { input: tensor.into_description(), @@ -511,7 +520,9 @@ impl BoolTensorOps for Fusion { let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(shape.dims.into()); + let out = tensor + .client + .tensor_uninitialized(shape.dims.into(), DType::Bool); let desc = ExpandOperationDescription { input: tensor.into_description(), @@ -547,7 +558,9 @@ impl BoolTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), DType::Bool); let desc = FlipOperationDescription { input: tensor.into_description(), @@ -588,7 +601,7 @@ impl BoolTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = times; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = RepeatOperationDescription { tensor: tensor.into_description(), diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index eb478af67b..1e42632def 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -10,7 +10,7 @@ use crate::{ use burn_tensor::{ ops::{BoolTensor, FloatElem, 
FloatTensor, FloatTensorOps, IntTensor}, repr::*, - Data, Device, Distribution, ElementConversion, Reader, Shape, + DType, Data, Device, Distribution, Element, ElementConversion, Reader, Shape, }; use std::{marker::PhantomData, ops::Range}; @@ -27,6 +27,7 @@ impl FloatTensorOps for Fusion { B::float_tensor_handle(tensor), shape.dims.into(), StreamId::current(), + B::FloatElem::dtype(), ) } @@ -53,7 +54,7 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = RandomOperationDescription { out: out.to_description_out(), @@ -86,7 +87,7 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = out.to_description_out(); client.register( @@ -116,7 +117,7 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = out.to_description_out(); client.register( @@ -152,7 +153,7 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = (out.to_description_out(), fill_value.elem::()); client.register( @@ -217,7 +218,9 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -237,7 +240,12 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let tensor = B::float_empty(shape.clone(), device); - client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into(), stream) + client.register_tensor( + B::float_tensor_handle(tensor), + shape.dims.into(), + stream, + B::FloatElem::dtype(), + ) } fn float_add( @@ -248,9 +256,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -274,7 +283,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(AddOps, B::float_add_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -313,7 +324,9 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = ClampOperationDescription { tensor: tensor.into_description(), @@ -338,9 
+351,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -363,7 +377,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(SubOps, B::float_sub_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), rhs: rhs.elem(), @@ -389,9 +405,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -414,7 +431,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(MulOps, B::float_mul_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -440,9 +459,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -465,7 +485,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(DivOps, B::float_div_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -490,7 +512,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(ModOps, B::float_remainder_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -521,7 +545,9 @@ impl FloatTensorOps for Fusion { shape[D - 2] = lhs.shape[D - 2]; shape[D - 1] = rhs.shape[D - 1]; - let out = lhs.client.tensor_uninitialized(shape); + let out = lhs + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = BinaryOperationDescription { lhs: lhs.into_description(), rhs: rhs.into_description(), @@ -561,7 +587,9 @@ impl FloatTensorOps for Fusion { shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let mut out = tensor.client.tensor_uninitialized(shape); + let mut out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SwapDimsDescription { input: tensor.into_description(), @@ -601,7 +629,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = 
ReshapeDescription { input: tensor.into_description(), @@ -640,7 +670,9 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = indices.stream; let shape: Vec = indices.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = GatherOperationDescription { tensor: tensor.into_description(), @@ -685,7 +717,9 @@ impl FloatTensorOps for Fusion { let stream_2 = indices.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScatterOperationDescription { tensor: tensor.into_description(), @@ -730,7 +764,9 @@ impl FloatTensorOps for Fusion { let stream_2 = indices.stream; let mut shape: Vec = tensor.shape.clone(); shape[dim] = indices.shape[0]; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SelectOperationDescription { tensor: tensor.into_description(), dim, @@ -774,7 +810,9 @@ impl FloatTensorOps for Fusion { let stream_2 = indices.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SelectAssignOperationDescription { tensor: tensor.into_description(), @@ -823,7 +861,9 @@ impl FloatTensorOps for Fusion { shape.push(tensor.shape[i]); } - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SliceOperationDescription { tensor: tensor.into_description(), @@ -870,7 +910,9 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SliceAssignOperationDescription { tensor: tensor.into_description(), @@ -914,7 +956,9 @@ impl FloatTensorOps for Fusion { let stream_2 = mask.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = MaskWhereOperationDescription { tensor: tensor.into_description(), @@ -958,7 +1002,9 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = mask.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = MaskFillOperationDescription { tensor: tensor.into_description(), value: value.elem(), @@ -984,7 +1030,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1007,7 +1053,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), 
DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1035,7 +1083,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1058,7 +1106,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1086,7 +1136,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1111,7 +1161,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1139,7 +1191,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1162,7 +1214,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1190,7 +1244,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1215,7 +1269,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1237,7 +1293,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(SumOps, B::float_sum, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1261,7 +1319,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1281,7 +1341,9 @@ impl FloatTensorOps for 
Fusion { unary_float_ops!(MeanOps, B::float_mean, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1305,7 +1367,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1325,7 +1389,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(ExpOps, B::float_exp); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: lhs.into_description(), @@ -1344,7 +1410,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(LogOps, B::float_log); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1363,7 +1431,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(Log1pOps, B::float_log1p); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1385,7 +1455,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(PowfOps, B::float_powf_scalar, f32); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1405,7 +1477,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(SqrtOps, B::float_sqrt); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1424,7 +1498,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(AbsOps, B::float_abs); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1443,7 +1519,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(CosOps, B::float_cos); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1462,7 +1540,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(SinOps, B::float_sin); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), 
@@ -1481,7 +1561,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(TanhOps, B::float_tanh); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1500,7 +1582,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(Recip, B::float_recip); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), @@ -1518,7 +1602,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(TanhOps, B::float_erf); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1569,7 +1655,7 @@ impl FloatTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(), @@ -1594,7 +1680,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1634,7 +1722,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = times; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = RepeatOperationDescription { tensor: tensor.into_description(), @@ -1660,7 +1750,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1680,7 +1772,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(MaxOps, B::float_max, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1704,7 +1798,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1744,8 +1840,8 @@ impl FloatTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape.clone(), B::FloatElem::dtype()); + let out_indices = 
client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReduceDimWithIndicesDescription { tensor: tensor.into_description(), @@ -1768,7 +1864,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(MinOps, B::float_min, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1792,7 +1890,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1832,8 +1932,8 @@ impl FloatTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape.clone(), B::FloatElem::dtype()); + let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReduceDimWithIndicesDescription { tensor: tensor.into_description(), @@ -1860,9 +1960,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1902,7 +2003,9 @@ impl FloatTensorOps for Fusion { // Change the shape of the tensor to match the new axes let shape = axes.into_iter().map(|x| tensor.shape[x]).collect(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = PermuteOperationDescription { input: tensor.into_description(), @@ -1943,7 +2046,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(shape.dims.into()); + let out = tensor + .client + .tensor_uninitialized(shape.dims.into(), B::FloatElem::dtype()); let desc = ExpandOperationDescription { input: tensor.into_description(), @@ -1979,7 +2084,9 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = FlipOperationDescription { input: tensor.into_description(), diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index efad010d6d..42dcd2b3c9 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -10,7 +10,7 @@ use crate::{ use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, repr::{self, *}, - Data, Device, Distribution, ElementConversion, Reader, Shape, + DType, Data, Device, Distribution, Element, ElementConversion, Reader, Shape, }; use core::ops::Range; use std::marker::PhantomData; @@ -21,7 +21,12 @@ impl IntTensorOps for Fusion { let tensor = B::int_empty(shape.clone(), device); let stream = StreamId::current(); - client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into(), stream) + client.register_tensor( + 
B::int_tensor_handle(tensor), + shape.dims.into(), + stream, + B::IntElem::dtype(), + ) } fn int_shape(tensor: &IntTensor) -> Shape { @@ -41,7 +46,12 @@ impl IntTensorOps for Fusion { let shape = B::int_shape(&tensor); let stream = StreamId::current(); - client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into(), stream) + client.register_tensor( + B::int_tensor_handle(tensor), + shape.dims.into(), + stream, + B::IntElem::dtype(), + ) } fn int_device(tensor: &IntTensor) -> Device { @@ -90,7 +100,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReshapeDescription { input: tensor.into_description(), @@ -135,7 +147,9 @@ impl IntTensorOps for Fusion { shape.push(tensor.shape[i]); } - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SliceOperationDescription { tensor: tensor.into_description(), @@ -182,7 +196,9 @@ impl IntTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SliceAssignOperationDescription { tensor: tensor.into_description(), ranges: ranges.into(), @@ -225,7 +241,9 @@ impl IntTensorOps for Fusion { let stream_2 = mask.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = MaskWhereOperationDescription { tensor: tensor.into_description(), @@ -267,7 +285,9 @@ impl IntTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = mask.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = MaskFillOperationDescription { tensor: tensor.into_description(), value: value.elem(), @@ -307,7 +327,9 @@ impl IntTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = indices.stream; let shape: Vec = indices.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = GatherOperationDescription { tensor: tensor.into_description(), dim, @@ -351,7 +373,9 @@ impl IntTensorOps for Fusion { let stream_2 = indices.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScatterOperationDescription { tensor: tensor.into_description(), dim, @@ -394,7 +418,9 @@ impl IntTensorOps for Fusion { let stream_2 = indices.stream; let mut shape: Vec = tensor.shape.clone(); shape[dim] = indices.shape[0]; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SelectOperationDescription { tensor: tensor.into_description(), dim, @@ -438,7 +464,9 @@ impl IntTensorOps for Fusion { let stream_2 = indices.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = 
tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SelectAssignOperationDescription { tensor: tensor.into_description(), dim, @@ -490,7 +518,7 @@ impl IntTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(), @@ -516,7 +544,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -539,7 +567,9 @@ impl IntTensorOps for Fusion { scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -565,7 +595,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -588,7 +618,9 @@ impl IntTensorOps for Fusion { scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -616,7 +648,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -641,7 +673,9 @@ impl IntTensorOps for Fusion { scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -669,7 +703,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -692,7 +726,9 @@ impl IntTensorOps for Fusion { scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -718,7 +754,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -741,7 +777,9 @@ impl IntTensorOps for Fusion { 
scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -767,9 +805,10 @@ impl IntTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -792,7 +831,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(AddOps, B::int_add_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -818,9 +859,10 @@ impl IntTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -843,7 +885,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(SubOps, B::int_sub_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -869,9 +913,10 @@ impl IntTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -894,7 +939,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(MulOps, B::int_mul_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -920,9 +967,10 @@ impl IntTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -945,7 +993,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(DivOps, B::int_div_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -970,7 +1020,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(ModOps, B::int_remainder_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: 
lhs.into_description(), @@ -1006,7 +1058,7 @@ impl IntTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = out.to_description_out(); client.register( vec![stream], @@ -1035,7 +1087,7 @@ impl IntTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = out.to_description_out(); client.register( @@ -1051,7 +1103,9 @@ impl IntTensorOps for Fusion { unary_int_ops!(SumOps, B::int_sum, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1072,7 +1126,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1092,7 +1148,9 @@ impl IntTensorOps for Fusion { unary_int_ops!(ProdOps, B::int_prod, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1113,7 +1171,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1133,7 +1193,9 @@ impl IntTensorOps for Fusion { unary_int_ops!(MeanOps, B::int_mean, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1154,7 +1216,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1176,7 +1240,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1198,7 +1264,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1235,7 +1303,9 @@ impl IntTensorOps for Fusion { } let stream = tensor.stream; - let out = 
tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); let desc = ClampOperationDescription { tensor: tensor.into_description(), min: min.elem(), @@ -1255,7 +1325,9 @@ impl IntTensorOps for Fusion { unary_int_ops!(AbsOps, B::int_abs); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1286,7 +1358,9 @@ impl IntTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), @@ -1324,7 +1398,9 @@ impl IntTensorOps for Fusion { shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SwapDimsDescription { input: tensor.into_description(), @@ -1345,7 +1421,9 @@ impl IntTensorOps for Fusion { unary_int_ops!(MaxOps, B::int_max, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1366,7 +1444,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1406,8 +1486,8 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape.clone(), B::IntElem::dtype()); + let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReduceDimWithIndicesDescription { tensor: tensor.into_description(), dim, @@ -1429,7 +1509,9 @@ impl IntTensorOps for Fusion { unary_int_ops!(MinOps, B::int_min, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1450,7 +1532,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1490,8 +1574,8 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape.clone(), B::IntElem::dtype()); + let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype()); 
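The `*_dim_with_indices` reductions above show why the explicit dtype matters: a single call now allocates two outputs with different element types, the reduced values keeping the tensor's own dtype while the index tensor always takes `B::IntElem::dtype()`. A small stand-alone sketch of that allocation pattern (illustrative names, not the burn API):

// Reducing `dim` collapses it to size 1; the value and index outputs
// share a shape but not a dtype. Stand-in types for illustration only.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F32, // the reduced values keep the input's element type
    I64, // the indices take the backend's integer element type
}

fn reduce_dim_with_indices_alloc(
    mut shape: Vec<usize>,
    dim: usize,
) -> ((Vec<usize>, DType), (Vec<usize>, DType)) {
    shape[dim] = 1;
    let out = (shape.clone(), DType::F32);
    let out_indices = (shape, DType::I64);
    (out, out_indices)
}

fn main() {
    let (out, indices) = reduce_dim_with_indices_alloc(vec![8, 4], 1);
    assert_eq!(out.0, vec![8, 1]);
    assert_eq!(indices.0, vec![8, 1]);
    assert_ne!(out.1, indices.1);
}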
let desc = ReduceDimWithIndicesDescription { tensor: tensor.into_description(), dim, @@ -1532,7 +1616,7 @@ impl IntTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = RandomOperationDescription { out: out.to_description_out(), @@ -1571,7 +1655,9 @@ impl IntTensorOps for Fusion { // Change the shape of the tensor to match the new axes let shape = axes.into_iter().map(|x| tensor.shape[x]).collect(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = PermuteOperationDescription { input: tensor.into_description(), @@ -1611,7 +1697,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(shape.dims.into()); + let out = tensor + .client + .tensor_uninitialized(shape.dims.into(), B::IntElem::dtype()); let desc = ExpandOperationDescription { input: tensor.into_description(), @@ -1646,7 +1734,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); let desc = FlipOperationDescription { input: tensor.into_description(), @@ -1687,7 +1777,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = times; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = RepeatOperationDescription { tensor: tensor.into_description(), diff --git a/crates/burn-fusion/src/ops/module.rs b/crates/burn-fusion/src/ops/module.rs index 7bbbfbdd1b..cb0003d5b2 100644 --- a/crates/burn-fusion/src/ops/module.rs +++ b/crates/burn-fusion/src/ops/module.rs @@ -10,6 +10,7 @@ use burn_tensor::{ ModuleOps, }, repr::*, + Element, }; use std::marker::PhantomData; @@ -64,7 +65,7 @@ impl ModuleOps> for Fusion { let stream_2 = weight.stream; let stream_3 = bias.as_ref().map(|b| b.stream); let shape = vec![x.shape[0], weight.shape[0], size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let description = Conv1dDescription { x: x.into_description(), @@ -129,7 +130,7 @@ impl ModuleOps> for Fusion { let stream_2 = weight.stream; let stream_3 = bias.as_ref().map(|b| b.stream); let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = Conv2dDescription { x: x.into_description(), @@ -188,7 +189,7 @@ impl ModuleOps> for Fusion { let stream_2 = weight.stream; let stream_3 = bias.as_ref().map(|b| b.stream); let shape = vec![x.shape[0], weight.shape[1] * options.groups, size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ConvTranspose1dDescription { x: x.into_description(), @@ -255,7 +256,7 @@ impl ModuleOps> for Fusion { let stream_2 = weight.stream; let stream_3 = bias.as_ref().map(|b| b.stream); let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, 
B::FloatElem::dtype()); let desc = ConvTranspose2dDescription { x: x.into_description(), @@ -305,7 +306,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]); let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = AvgPool1dDescription { x: x.into_description(), @@ -355,7 +356,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = AvgPool2dDescription { x: x.into_description(), @@ -403,7 +404,9 @@ impl ModuleOps> for Fusion { let stream_1 = x.stream; let stream_2 = grad.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = AvgPool1dBackwardDescription { x: x.into_description(), @@ -454,7 +457,9 @@ impl ModuleOps> for Fusion { let stream_1 = x.stream; let stream_2 = grad.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = AvgPool2dBackwardDescription { x: x.into_description(), @@ -504,7 +509,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = MaxPool1dDescription { x: x.into_description(), @@ -564,7 +569,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = MaxPool2dDescription { x: x.into_description(), @@ -611,8 +616,10 @@ impl ModuleOps> for Fusion { let stream = x.stream; let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape.clone()); - let out_indices = x.client.tensor_uninitialized(shape); + let out = x + .client + .tensor_uninitialized(shape.clone(), B::FloatElem::dtype()); + let out_indices = x.client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = MaxPool1dWithIndicesDescription { x: x.into_description(), @@ -676,8 +683,10 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape.clone()); - let out_indices = x.client.tensor_uninitialized(shape); + let out = x + .client + .tensor_uninitialized(shape.clone(), B::FloatElem::dtype()); + let out_indices = x.client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = MaxPool2dWithIndicesDescription { x: x.into_description(), @@ -733,7 +742,9 @@ impl ModuleOps> for Fusion { let stream_1 = x.stream; let stream_2 = output_grad.stream; let stream_3 = indices.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = MaxPool1dWithIndicesBackwardDescription { x: x.into_description(), @@ -790,7 +801,9 @@ impl ModuleOps> for Fusion { let stream_1 = x.stream; let stream_2 = 
output_grad.stream; let stream_3 = indices.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = MaxPool2dWithIndicesBackwardDescription { x: x.into_description(), @@ -827,7 +840,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], output_size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = AdaptiveAvgPool1dDescription { x: x.into_description(), @@ -862,7 +875,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = AdaptiveAvgPool2dDescription { x: x.into_description(), @@ -899,7 +912,9 @@ impl ModuleOps> for Fusion { let stream_1 = x.stream; let stream_2 = grad.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = AdaptiveAvgPool1dBackwardDescription { x: x.into_description(), grad: grad.into_description(), @@ -936,7 +951,9 @@ impl ModuleOps> for Fusion { let stream_1 = x.stream; let stream_2 = grad.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = AdaptiveAvgPool2dBackwardDescription { x: x.into_description(), @@ -971,7 +988,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = InterpolateDescription { x: x.into_description(), @@ -1010,7 +1027,9 @@ impl ModuleOps> for Fusion { let stream_1 = x.stream; let stream_2 = grad.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = InterpolateBackwardDescription { x: x.into_description(), diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index b66c89a4b7..422084a782 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -883,6 +883,7 @@ impl RelativeOps for TensorDescription { id: relative_id, shape: relative_shape, status: self.status.clone(), + dtype: self.dtype, }; // We update both mappings. 
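The context hunk above extends stream-relative conversion: when a `TensorDescription` is rewritten into its relative form, the id and shape dimensions go through the converter's mappings, while `dtype` must be copied through unchanged. A simplified sketch of that invariant (the real converter keeps id and dimension maps that are elided here):

// Simplified stand-in types; the point is that relative conversion
// remaps id/shape but preserves `dtype` verbatim.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F32,
}

#[derive(Clone, Debug, PartialEq)]
struct Desc {
    id: u64,
    shape: Vec<usize>,
    dtype: DType,
}

fn to_relative(desc: &Desc, next_id: &mut u64) -> Desc {
    let id = *next_id; // remapped id, as in the converter
    *next_id += 1;
    Desc {
        id,
        shape: vec![0; desc.shape.len()], // placeholder for the real dim mapping
        dtype: desc.dtype,                // carried through unchanged
    }
}

fn main() {
    let mut next_id = 0;
    let original = Desc {
        id: 500,
        shape: vec![512, 32, 2048],
        dtype: DType::F32,
    };
    let relative = to_relative(&original, &mut next_id);
    assert_eq!(relative.id, 0);
    assert_eq!(relative.dtype, original.dtype);
}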
@@ -900,7 +901,10 @@ impl RelativeOps for TensorDescription { #[cfg(test)] mod tests { use super::*; - use burn_tensor::repr::{TensorDescription, TensorId, TensorStatus}; + use burn_tensor::{ + repr::{TensorDescription, TensorId, TensorStatus}, + DType, + }; #[test] fn tensor_description_to_relative() { @@ -908,11 +912,13 @@ mod tests { id: TensorId::new(500), shape: vec![512, 32, 2048], status: TensorStatus::ReadOnly, + dtype: DType::F32, }; let tensor2 = TensorDescription { id: TensorId::new(501), shape: vec![512, 128, 2048], status: TensorStatus::ReadOnly, + dtype: DType::F32, }; let mut converter = OperationConverter::default(); let tensor1_local = tensor1.to_relative(&mut converter); @@ -923,7 +929,8 @@ mod tests { TensorDescription { id: TensorId::new(0), shape: vec![0, 1, 2], - status: TensorStatus::ReadOnly + status: TensorStatus::ReadOnly, + dtype: DType::F32 } ); assert_eq!( @@ -931,7 +938,8 @@ mod tests { TensorDescription { id: TensorId::new(1), shape: vec![0, 3, 2], - status: TensorStatus::ReadOnly + status: TensorStatus::ReadOnly, + dtype: DType::F32 } ); } diff --git a/crates/burn-fusion/src/stream/execution/policy.rs b/crates/burn-fusion/src/stream/execution/policy.rs index e56f7bce62..5424ec5364 100644 --- a/crates/burn-fusion/src/stream/execution/policy.rs +++ b/crates/burn-fusion/src/stream/execution/policy.rs @@ -265,9 +265,12 @@ impl Policy { #[cfg(test)] mod tests { - use burn_tensor::repr::{ - FloatOperationDescription, TensorDescription, TensorId, TensorStatus, - UnaryOperationDescription, + use burn_tensor::{ + repr::{ + FloatOperationDescription, TensorDescription, TensorId, TensorStatus, + UnaryOperationDescription, + }, + DType, }; use super::*; @@ -557,6 +560,7 @@ mod tests { id: TensorId::new(id), shape: vec![32, 32, 1], status: TensorStatus::NotInit, + dtype: DType::F32, }); } diff --git a/crates/burn-fusion/src/stream/execution/tests.rs b/crates/burn-fusion/src/stream/execution/tests.rs index 6755b624b2..0e13546e17 100644 --- a/crates/burn-fusion/src/stream/execution/tests.rs +++ b/crates/burn-fusion/src/stream/execution/tests.rs @@ -6,10 +6,13 @@ //! To test these components effectively, we create mock types for the stream, optimization, //! optimization builder, and stream segment. These mock types aid in comprehensively //! understanding the process of optimizing streams. 
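A note on why the test fixtures below all gain a `dtype: DType::F32` field: execution plans are matched against stored operation descriptions, so once `dtype` participates in a description's equality, two streams that differ only in element type can no longer be mistaken for one another. A hedged illustration with a simplified struct (not the burn type):

// Descriptions differing only in dtype must compare unequal, so they
// cannot share a cached execution plan. Simplified stand-in struct.
#[derive(Clone, Debug, PartialEq)]
struct Desc {
    id: u64,
    shape: Vec<usize>,
    dtype: &'static str,
}

fn main() {
    let a = Desc { id: 0, shape: vec![32, 32], dtype: "f32" };
    let b = Desc { dtype: "f16", ..a.clone() };
    assert_eq!(a.shape, b.shape);
    assert_ne!(a, b); // distinct keys once dtype is part of the description
}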
-use burn_tensor::repr::{ - BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription, - OperationDescription, ScalarOperationDescription, TensorDescription, TensorId, TensorStatus, - UnaryOperationDescription, +use burn_tensor::{ + repr::{ + BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription, + OperationDescription, ScalarOperationDescription, TensorDescription, TensorId, + TensorStatus, UnaryOperationDescription, + }, + DType, }; use crate::{ @@ -523,16 +526,19 @@ fn operation_1() -> OperationDescription { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: TensorDescription { id: TensorId::new(1), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) @@ -546,12 +552,14 @@ fn operation_2() -> OperationDescription { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: 5.0, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) @@ -564,11 +572,13 @@ fn operation_3() -> OperationDescription { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, out: TensorDescription { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, })) } diff --git a/crates/burn-fusion/src/stream/store/index.rs b/crates/burn-fusion/src/stream/store/index.rs index b1c3111ac9..15bb56341a 100644 --- a/crates/burn-fusion/src/stream/store/index.rs +++ b/crates/burn-fusion/src/stream/store/index.rs @@ -116,9 +116,12 @@ impl ExecutionPlanIndex { #[cfg(test)] mod tests { - use burn_tensor::repr::{ - BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription, - TensorDescription, TensorId, TensorStatus, + use burn_tensor::{ + repr::{ + BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription, + TensorDescription, TensorId, TensorStatus, + }, + DType, }; use super::*; @@ -221,16 +224,19 @@ mod tests { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: TensorDescription { id: TensorId::new(1), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) @@ -243,12 +249,14 @@ mod tests { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: 5.0, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) @@ -261,16 +269,19 @@ mod tests { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: TensorDescription { id: TensorId::new(1), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 6fad70e723..5204ff0482 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -3,7 +3,7 @@ use burn_tensor::{ backend::Backend, ops::{FloatElem, IntElem}, repr::{TensorDescription, TensorId, 
TensorStatus}, - Data, Reader, Shape, + DType, Data, Reader, Shape, }; use std::sync::Arc; @@ -16,6 +16,8 @@ pub struct FusionTensor { pub shape: Vec, /// The [fusion client](FusionClient). pub client: C, + /// The datatype of the tensor. + pub dtype: DType, // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`. // // When a tensor is dropped and is still an orphan, we need to register it as such to avoid @@ -41,11 +43,18 @@ impl core::fmt::Debug for FusionTensor { } impl FusionTensor { - pub(crate) fn new(id: Arc, shape: Vec, client: C, stream: StreamId) -> Self { + pub(crate) fn new( + id: Arc, + shape: Vec, + dtype: DType, + client: C, + stream: StreamId, + ) -> Self { Self { id, shape, client, + dtype, is_orphan: true, stream, } @@ -68,6 +77,7 @@ impl FusionTensor { status: TensorStatus::NotInit, shape: self.shape.clone(), id: *self.id.as_ref(), + dtype: self.dtype, } } @@ -85,6 +95,7 @@ impl FusionTensor { status, shape: shape_out, id: *self.id.as_ref(), + dtype: self.dtype, } } diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 95ddf6fafe..61dd1e0d20 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -109,7 +109,7 @@ impl FusionRuntime for FusionJitRuntime { fn optimizations( device: R::Device, ) -> Vec>> { - vec![Box::new(ElementWiseBuilder::::new(device))] + vec![Box::new(ElementWiseBuilder::::new(device))] } } diff --git a/crates/burn-tensor/src/repr/tensor.rs b/crates/burn-tensor/src/repr/tensor.rs index 525ad9c50b..a68d6b9c2f 100644 --- a/crates/burn-tensor/src/repr/tensor.rs +++ b/crates/burn-tensor/src/repr/tensor.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::DType; + /// The tensor unique identifier. #[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)] pub struct TensorId { @@ -35,6 +37,8 @@ pub struct TensorDescription { pub shape: Vec, /// The [status](TensorStatus) of the tensor when it was used. pub status: TensorStatus, + /// The [type](DType) of the tensor. + pub dtype: DType, } impl TensorId { diff --git a/crates/burn-tensor/src/tensor/element.rs b/crates/burn-tensor/src/tensor/element.rs index d8eaac94b6..e64736d083 100644 --- a/crates/burn-tensor/src/tensor/element.rs +++ b/crates/burn-tensor/src/tensor/element.rs @@ -4,6 +4,7 @@ use crate::Distribution; use half::{bf16, f16}; use num_traits::{identities::Zero, One, ToPrimitive}; use rand::RngCore; +use serde::{Deserialize, Serialize}; /// Element trait for tensor. pub trait Element: @@ -22,6 +23,8 @@ pub trait Element: + Copy + 'static { + /// The dtype of the element. + fn dtype() -> DType; } /// Element conversion trait for tensor. @@ -93,10 +96,15 @@ macro_rules! 
make_element { ty $type:ident $precision:expr, convert $convert:expr, random $random:expr, - cmp $cmp:expr + cmp $cmp:expr, + dtype $dtype:expr ) => { - impl Element for $type {} + impl Element for $type { + fn dtype() -> $crate::DType { + $dtype + } + } impl ElementConversion for $type { fn from_elem(elem: E) -> Self { @@ -136,56 +144,64 @@ make_element!( ty f64 Precision::Double, convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), - cmp |a: &f64, b: &f64| a.total_cmp(b) + cmp |a: &f64, b: &f64| a.total_cmp(b), + dtype DType::F32 ); make_element!( ty f32 Precision::Full, convert |elem: &dyn ToPrimitive| elem.to_f32().unwrap(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), - cmp |a: &f32, b: &f32| a.total_cmp(b) + cmp |a: &f32, b: &f32| a.total_cmp(b), + dtype DType::F32 ); make_element!( ty i64 Precision::Double, convert |elem: &dyn ToPrimitive| elem.to_i64().unwrap(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), - cmp |a: &i64, b: &i64| Ord::cmp(a, b) + cmp |a: &i64, b: &i64| Ord::cmp(a, b), + dtype DType::I64 ); make_element!( ty i32 Precision::Full, convert |elem: &dyn ToPrimitive| elem.to_i32().unwrap(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), - cmp |a: &i32, b: &i32| Ord::cmp(a, b) + cmp |a: &i32, b: &i32| Ord::cmp(a, b), + dtype DType::I32 ); make_element!( ty u32 Precision::Full, convert |elem: &dyn ToPrimitive| elem.to_u32().unwrap(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), - cmp |a: &u32, b: &u32| Ord::cmp(a, b) + cmp |a: &u32, b: &u32| Ord::cmp(a, b), + dtype DType::U32 ); make_element!( ty i16 Precision::Half, convert |elem: &dyn ToPrimitive| elem.to_i16().unwrap(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), - cmp |a: &i16, b: &i16| Ord::cmp(a, b) + cmp |a: &i16, b: &i16| Ord::cmp(a, b), + dtype DType::I16 ); make_element!( ty i8 Precision::Other, convert |elem: &dyn ToPrimitive| elem.to_i8().unwrap(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), - cmp |a: &i8, b: &i8| Ord::cmp(a, b) + cmp |a: &i8, b: &i8| Ord::cmp(a, b), + dtype DType::I8 ); make_element!( ty u8 Precision::Other, convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), - cmp |a: &u8, b: &u8| Ord::cmp(a, b) + cmp |a: &u8, b: &u8| Ord::cmp(a, b), + dtype DType::U8 ); make_element!( @@ -195,7 +211,8 @@ make_element!( let sample: f32 = distribution.sampler(rng).sample(); f16::from_elem(sample) }, - cmp |a: &f16, b: &f16| a.total_cmp(b) + cmp |a: &f16, b: &f16| a.total_cmp(b), + dtype DType::F16 ); make_element!( ty bf16 Precision::Half, @@ -204,5 +221,23 @@ make_element!( let sample: f32 = distribution.sampler(rng).sample(); bf16::from_elem(sample) }, - cmp |a: &bf16, b: &bf16| a.total_cmp(b) + cmp |a: &bf16, b: &bf16| a.total_cmp(b), + dtype DType::BF16 ); + +#[allow(missing_docs)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub enum DType { + F64, + F32, + F16, + BF16, + I64, + I32, + I16, + I8, + U64, + U32, + U8, + Bool, +} From cb9c515f33154aa6dff3889765b8069d89b22741 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 30 Apr 2024 15:58:03 -0400 Subject: [PATCH 06/25] Update JIT --- .../src/codegen/dialect/gpu/shader.rs | 20 + 
.../burn-jit/src/fusion/elemwise/builder.rs | 345 +++++++----------- crates/burn-jit/src/fusion/tracing/builder.rs | 6 +- 3 files changed, 166 insertions(+), 205 deletions(-) diff --git a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs index a4651f029a..b8373d7fe5 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs @@ -1,5 +1,6 @@ use super::Scope; use crate::kernel::WORKGROUP_DEFAULT; +use burn_tensor::DType; use serde::{Deserialize, Serialize}; use std::fmt::Display; @@ -46,6 +47,25 @@ impl From for Item { } } +impl From for Elem { + fn from(dtype: DType) -> Self { + match dtype { + DType::F64 => Elem::Float(FloatKind::F64), + DType::F32 => Elem::Float(FloatKind::F32), + DType::F16 => todo!(), + DType::BF16 => todo!(), + DType::I64 => Elem::Int(IntKind::I64), + DType::I32 => Elem::Int(IntKind::I32), + DType::I16 => todo!(), + DType::I8 => todo!(), + DType::U64 => Elem::UInt, + DType::U32 => Elem::UInt, + DType::U8 => todo!(), + DType::Bool => Elem::Bool, + } + } +} + impl Display for Elem { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 0cd346af70..78c3d28074 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -1,42 +1,31 @@ -use core::marker::PhantomData; - use super::{optimization::ElementWise, CompilationPhase}; use crate::{ codegen::dialect::gpu::{ - BinaryOperator, ConditionalAssign, Elem, Operator, Procedure, UnaryOperator, Variable, + BinaryOperator, ConditionalAssign, Operator, Procedure, UnaryOperator, Variable, }, - element::JitElement, fusion::{tracing::TraceBuilder, JitOptimization}, - FloatElement, IntElement, JitBackend, Runtime, + Runtime, }; use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus}; use burn_tensor::{ - ops::{FloatElem, IntElem}, repr::{ BaseOperationDescription, BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription, OperationDescription, ScalarOperationDescription, TensorDescription, UnaryOperationDescription, }, - Device, Element, + Element, }; /// Fused element wise operations that are normally memory bound. 
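// --- Editor's note (illustrative aside, not part of the patch) ---------------
// The `From<DType> for Elem` impl above is what lets the builder refactor below
// drop its `F: FloatElement` / `I: IntElement` parameters: each tensor
// description now knows its own element type, so the lookup reduces to a single
// conversion (helper name is ours; `Elem` is the gpu dialect element type):
fn gpu_elem_of(tensor: &burn_tensor::repr::TensorDescription) -> Elem {
    // Per-tensor dtype replaces the old `E::gpu_elem()` generic calls.
    tensor.dtype.into()
}
// ------------------------------------------------------------------------------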
-pub(crate) struct ElementWiseBuilder { +pub(crate) struct ElementWiseBuilder { builder: TraceBuilder, current_output_shape: Vec, status: OptimizationStatus, num_added: usize, device: R::Device, - _float_elem: PhantomData, - _int_elem: PhantomData, } -impl OptimizationBuilder> for ElementWiseBuilder -where - R: Runtime, - F: FloatElement, - I: IntElement, -{ +impl OptimizationBuilder> for ElementWiseBuilder { fn register(&mut self, ops: &OperationDescription) { if let OptimizationStatus::Closed = self.status { return; @@ -44,31 +33,31 @@ where match ops { OperationDescription::BaseFloat(ops) => { - if !self.register_base::>>(ops) { + if !self.register_base(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::BaseInt(ops) => { - if !self.register_base::>>(ops) { + if !self.register_base(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::Float(ops) => { - if !self.register_float::>>(ops) { + if !self.register_float(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::NumericFloat(ops) => { - if !self.register_numeric::>, _>(ops) { + if !self.register_numeric::(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::NumericInt(ops) => { - if !self.register_numeric::>, _>(ops) { + if !self.register_numeric::(ops) { self.status = OptimizationStatus::Closed; return; } @@ -119,185 +108,146 @@ where } } -impl ElementWiseBuilder { - pub fn new(device: Device>) -> Self { +impl ElementWiseBuilder { + pub fn new(device: R::Device) -> Self { Self { builder: TraceBuilder::new(), num_added: 0, current_output_shape: Vec::new(), status: OptimizationStatus::Open, device, - _float_elem: PhantomData, - _int_elem: PhantomData, } } - fn register_base(&mut self, ops: &BaseOperationDescription) -> bool { + fn register_base(&mut self, ops: &BaseOperationDescription) -> bool { match ops { - BaseOperationDescription::Equal(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Equal(BinaryOperator { lhs, rhs, out }), - ), + BaseOperationDescription::Equal(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Equal(BinaryOperator { lhs, rhs, out }) + }), _ => false, } } - fn register_float(&mut self, ops: &FloatOperationDescription) -> bool { + fn register_float(&mut self, ops: &FloatOperationDescription) -> bool { match ops { - FloatOperationDescription::Exp(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Exp(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Log(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Log(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Log1p(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { + FloatOperationDescription::Exp(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Exp(UnaryOperator { input, out }) + }), + FloatOperationDescription::Log(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Log(UnaryOperator { input, out }) + }), + FloatOperationDescription::Log1p(desc) => self + .register_unary_ops(desc, |input, out| { Operator::Log1p(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Cos(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Cos(UnaryOperator { input, out }) - }) - } - 
FloatOperationDescription::Sin(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Sin(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::PowfScalar(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Powf(BinaryOperator { lhs, rhs, out }), - ), - FloatOperationDescription::Tanh(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Tanh(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Erf(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Erf(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Recip(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { + }), + FloatOperationDescription::Cos(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Cos(UnaryOperator { input, out }) + }), + FloatOperationDescription::Sin(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Sin(UnaryOperator { input, out }) + }), + FloatOperationDescription::PowfScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Powf(BinaryOperator { lhs, rhs, out }) + }), + FloatOperationDescription::Tanh(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Tanh(UnaryOperator { input, out }) + }), + FloatOperationDescription::Erf(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Erf(UnaryOperator { input, out }) + }), + FloatOperationDescription::Recip(desc) => self + .register_unary_ops(desc, |input, out| { Operator::Recip(UnaryOperator { input, out }) - }) - } + }), _ => false, } } - fn register_numeric( - &mut self, - ops: &NumericOperationDescription, - ) -> bool { + fn register_numeric(&mut self, ops: &NumericOperationDescription) -> bool { match ops { - NumericOperationDescription::Add(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Add(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::AddScalar(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Add(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::Sub(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Sub(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::SubScalar(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Sub(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::Mul(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Mul(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::MulScalar(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Mul(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::Div(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Div(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::DivScalar(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Div(BinaryOperator { lhs, rhs, out }), - ), - 
NumericOperationDescription::Abs(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { + NumericOperationDescription::Add(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Add(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::AddScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Add(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Sub(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Sub(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::SubScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Sub(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Mul(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Mul(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::MulScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Mul(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Div(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Div(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::DivScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Div(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Abs(desc) => self + .register_unary_ops(desc, |input, out| { Operator::Abs(UnaryOperator { input, out }) - }) - } - NumericOperationDescription::Lower(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Lower(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::LowerElem(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Lower(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::Greater(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Greater(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::GreaterElem(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Greater(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::LowerEqual(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::LowerEqual(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::LowerEqualElem(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::LowerEqual(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::GreaterEqual(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::GreaterEqualElem(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::EqualElem(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Equal(BinaryOperator { lhs, rhs, out }), - ), + }), + NumericOperationDescription::Lower(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Lower(BinaryOperator { lhs, rhs, out }) + }), + 
NumericOperationDescription::LowerElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Lower(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Greater(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Greater(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::GreaterElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Greater(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::LowerEqual(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::LowerEqual(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::LowerEqualElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::LowerEqual(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::GreaterEqual(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::GreaterEqualElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::EqualElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Equal(BinaryOperator { lhs, rhs, out }) + }), NumericOperationDescription::MaskWhere(desc) => { if !self.output_is_compatible(&desc.out) { return false; } - let cond = self.builder.input(&desc.mask, Elem::Bool); - let lhs = self.builder.input(&desc.value, E::gpu_elem()); - let rhs = self.builder.input(&desc.tensor, E::gpu_elem()); - let out = self.builder.output(&desc.out, E::gpu_elem()); + let cond = self.builder.input(&desc.mask); + let lhs = self.builder.input(&desc.value); + let rhs = self.builder.input(&desc.tensor); + let out = self.builder.output(&desc.out); self.builder .register_operation(Procedure::ConditionalAssign(ConditionalAssign { @@ -314,10 +264,10 @@ impl ElementWiseBuilder { return false; } - let cond = self.builder.input(&desc.mask, Elem::Bool); - let lhs = self.builder.scalar(&desc.value, E::gpu_elem()); - let rhs = self.builder.input(&desc.tensor, E::gpu_elem()); - let out = self.builder.output(&desc.out, E::gpu_elem()); + let cond = self.builder.input(&desc.mask); + let lhs = self.builder.scalar(&desc.value, desc.out.dtype.into()); + let rhs = self.builder.input(&desc.tensor); + let out = self.builder.output(&desc.out); self.builder .register_operation(Procedure::ConditionalAssign(ConditionalAssign { @@ -334,8 +284,8 @@ impl ElementWiseBuilder { return false; } - let input = Variable::ConstantScalar(1.0, E::gpu_elem()); - let out = self.builder.output(desc, E::gpu_elem()); + let input = Variable::ConstantScalar(1.0, desc.dtype.into()); + let out = self.builder.output(desc); self.builder .register_operation(Operator::Assign(UnaryOperator { input, out })); @@ -347,8 +297,8 @@ impl ElementWiseBuilder { return false; } - let input = Variable::ConstantScalar(0.0, E::gpu_elem()); - let out = self.builder.output(desc, E::gpu_elem()); + let input = Variable::ConstantScalar(0.0, desc.dtype.into()); + let out = self.builder.output(desc); self.builder .register_operation(Operator::Assign(UnaryOperator { input, out })); @@ -360,8 +310,8 @@ impl ElementWiseBuilder { return false; } - let input = self.builder.scalar(elem, E::gpu_elem()); - let out = self.builder.output(desc, E::gpu_elem()); + let input = self.builder.scalar(elem, desc.dtype.into()); + let out = self.builder.output(desc); self.builder 
.register_operation(Operator::Assign(UnaryOperator { input, out })); @@ -372,12 +322,7 @@ impl ElementWiseBuilder { } } - fn register_binary_ops( - &mut self, - desc: &BinaryOperationDescription, - (elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem), - func: Func, - ) -> bool + fn register_binary_ops(&mut self, desc: &BinaryOperationDescription, func: Func) -> bool where Func: Fn(Variable, Variable, Variable) -> Operator, { @@ -385,21 +330,16 @@ impl ElementWiseBuilder { return false; } - let lhs = self.builder.input(&desc.lhs, elem_lhs); - let rhs = self.builder.input(&desc.rhs, elem_rhs); - let out = self.builder.output(&desc.out, elem_out); + let lhs = self.builder.input(&desc.lhs); + let rhs = self.builder.input(&desc.rhs); + let out = self.builder.output(&desc.out); self.builder.register_operation(func(lhs, rhs, out)); true } - fn register_unary_ops( - &mut self, - desc: &UnaryOperationDescription, - (elem_input, elem_out): (Elem, Elem), - func: Func, - ) -> bool + fn register_unary_ops(&mut self, desc: &UnaryOperationDescription, func: Func) -> bool where Func: Fn(Variable, Variable) -> Operator, { @@ -407,8 +347,8 @@ impl ElementWiseBuilder { return false; } - let input = self.builder.input(&desc.input, elem_input); - let out = self.builder.output(&desc.out, elem_out); + let input = self.builder.input(&desc.input); + let out = self.builder.output(&desc.out); self.builder.register_operation(func(input, out)); @@ -418,7 +358,6 @@ impl ElementWiseBuilder { fn register_scalar_ops( &mut self, desc: &ScalarOperationDescription, - (elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem), func: Func, ) -> bool where @@ -428,9 +367,9 @@ impl ElementWiseBuilder { return false; } - let lhs = self.builder.input(&desc.lhs, elem_lhs); - let rhs = self.builder.scalar(&desc.rhs, elem_rhs); - let out = self.builder.output(&desc.out, elem_out); + let lhs = self.builder.input(&desc.lhs); + let rhs = self.builder.scalar(&desc.rhs, desc.lhs.dtype.into()); + let out = self.builder.output(&desc.out); self.builder.register_operation(func(lhs, rhs, out)); diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index 0377e19e71..49be4e2bd4 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -37,8 +37,9 @@ impl TraceBuilder { } /// Create a variable from an input [tensor description](TensorDescription). - pub fn input(&mut self, tensor: &TensorDescription, elem: gpu::Elem) -> gpu::Variable { + pub fn input(&mut self, tensor: &TensorDescription) -> gpu::Variable { let already_exists = self.tensors.contains_key(&tensor.id); + let elem = tensor.dtype.into(); let variable = match already_exists { false => { @@ -72,7 +73,8 @@ impl TraceBuilder { } /// Create a variable from an output [tensor description](TensorDescription). - pub fn output(&mut self, tensor: &TensorDescription, elem: gpu::Elem) -> gpu::Variable { + pub fn output(&mut self, tensor: &TensorDescription) -> gpu::Variable { + let elem = tensor.dtype.into(); // Update the tensor description to the new version. 
self.tensors.insert(tensor.id, (tensor.clone(), elem));

From 31a2467e64c72a2eb247fb665132e94a13af4beb Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Tue, 30 Apr 2024 19:09:35 -0400
Subject: [PATCH 07/25] WIP but working

---
 crates/burn-fusion/src/backend.rs | 51 +++++++---
 crates/burn-fusion/src/bridge.rs | 123 ++++++++++++++++++++---
 crates/burn-fusion/src/client/base.rs | 7 ++
 crates/burn-fusion/src/client/mutex.rs | 23 +++++
 crates/burn-fusion/src/server.rs | 12 +--
 crates/burn-fusion/src/stream/context.rs | 6 ++
 crates/burn-jit/src/fusion/base.rs | 22 +++-
 crates/burn-tensor/src/repr/operation.rs | 10 ++
 8 files changed, 216 insertions(+), 38 deletions(-)

diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs
index c17ce0e351..f70c7aa47f 100644
--- a/crates/burn-fusion/src/backend.rs
+++ b/crates/burn-fusion/src/backend.rs
@@ -1,8 +1,11 @@
 use crate::{
- client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
+ client::{FusionClient, MutexFusionClient},
+ stream::Context,
+ FusionClientLocator, FusionTensor, PrecisionBridge,
 };
 use burn_tensor::{
- backend::Backend,
+ backend::{Backend, DeviceOps},
+ ops::FloatTensor,
 repr::{OperationDescription, ReprBackend},
 Device,
 };
@@ -11,7 +14,7 @@ use std::marker::PhantomData;

 pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();

-pub(crate) fn get_client(device: &Device) -> B::FusionClient {
+pub(crate) fn get_client(device: &Device) -> MutexFusionClient {
 CLIENTS.client(device)
 }

@@ -21,20 +24,23 @@ pub struct Fusion {
 _backend: PhantomData,
 }

-impl Backend for Fusion {
+impl Backend for Fusion
+where
+ B: FusionBackend,
+{
 type Device = B::Device;

- type FullPrecisionBridge = PrecisionBridge;
+ type FullPrecisionBridge = PrecisionBridge;

- type FloatTensorPrimitive = FusionTensor;
+ type FloatTensorPrimitive = FusionTensor>;

 type FloatElem = B::FloatElem;

- type IntTensorPrimitive = FusionTensor;
+ type IntTensorPrimitive = FusionTensor>;

 type IntElem = B::IntElem;

- type BoolTensorPrimitive = FusionTensor;
+ type BoolTensorPrimitive = FusionTensor>;

 fn name() -> String {
 format!("fusion<{}>", B::name())
 }
@@ -45,10 +51,14 @@ impl Backend for Fusion {
 }

 fn sync(device: &Self::Device) {
- let client = CLIENTS.client::(&device.clone());
+ let client = CLIENTS.client::>(&device.clone());
 client.drain();
 B::sync(device)
 }
+
+ fn ad_enabled() -> bool {
+ false
+ }
 }

 /// The status of a [builder](OptimizationBuilder).
@@ -127,9 +137,9 @@ where
 /// Optimization type for the backend.
 type Optimization: Optimization;
 /// Handle
- type FusionHandle: Clone;
+ type FusionHandle: Clone + Send;
 /// Device
- type FusionDevice: Clone;
+ type FusionDevice: DeviceOps;

 /// The list of optimizations that will be used to optimize the computational graph.
 fn optimizations(
@@ -139,9 +149,20 @@

 /// Trait that allows an existing [backend](Backend) to specify graph optimizations using
 /// [operation builder](crate::OptimizationBuilder).
-pub trait FusionBackend: ReprBackend {
- /// What kind of client should be used.
- type FusionClient: FusionClient;
+pub trait FusionBackend:
+ ReprBackend<
+ Handle = ::FusionHandle,
+ Device = ::FusionDevice,
+>
+{
 /// The runtime.
- type FusionRuntime: FusionRuntime;
+ type FusionRuntime: FusionRuntime;
+
+ fn cast_float(
+ tensor: FloatTensor,
+ dtype: burn_tensor::DType,
+ ) -> Self::Handle;
+
+ /// Pointer to the full precision fusion backend.
+ type FullPrecisionBackend: FusionBackend; } diff --git a/crates/burn-fusion/src/bridge.rs b/crates/burn-fusion/src/bridge.rs index 375fd4fb52..8c7467714c 100644 --- a/crates/burn-fusion/src/bridge.rs +++ b/crates/burn-fusion/src/bridge.rs @@ -1,25 +1,122 @@ -use burn_tensor::backend::BackendBridge; - -use crate::{Fusion, FusionBackend}; +use crate::{ + client::FusionClient, stream::execution::Operation, Fusion, FusionBackend, FusionRuntime, +}; +use burn_tensor::{ + backend::BackendBridge, + ops::FloatTensor, + repr::{ + BaseOperationDescription, CastOperationDescription, HandleContainer, OperationDescription, + }, + Element, +}; +use std::marker::PhantomData; #[derive(Debug)] /// Fusion bridge. -pub struct PrecisionBridge; +pub struct PrecisionBridge { + _b: PhantomData, +} -impl BackendBridge> for PrecisionBridge { - type Target = Fusion; +impl BackendBridge> for PrecisionBridge +where + BInput: FusionBackend, + BOutput: FusionBackend, +{ + type Target = Fusion; fn into_target( - tensor: burn_tensor::ops::FloatTensor, D>, + tensor: FloatTensor, D>, _device: Option>, - ) -> burn_tensor::ops::FloatTensor { - tensor + ) -> FloatTensor { + #[derive(new)] + struct Cast { + desc: CastOperationDescription, + _bi: PhantomData, + _bt: PhantomData, + } + + impl Operation for Cast + where + BInput: FusionBackend, + BOutput: FusionBackend, + { + fn execute( + self: Box, + handles: &mut HandleContainer< + ::FusionHandle, + >, + ) { + let input = handles.get_float_tensor::(&self.desc.input); + let output = BInput::cast_float(input, BOutput::FloatElem::dtype()); + + handles.register_handle(self.desc.out.id, output); + } + } + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), BOutput::FloatElem::dtype()); + + let desc = CastOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + out.client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())), + Cast::::new(desc), + ); + + out.client.clone().to_backend::(out) } fn from_target( - tensor: burn_tensor::ops::FloatTensor, - _device: Option>>, - ) -> burn_tensor::ops::FloatTensor, D> { - tensor + tensor: FloatTensor, + _device: Option>>, + ) -> FloatTensor, D> { + #[derive(new)] + struct Cast { + desc: CastOperationDescription, + _bi: PhantomData, + _bt: PhantomData, + } + + impl Operation for Cast + where + BInput: FusionBackend, + BOutput: FusionBackend, + { + fn execute( + self: Box, + handles: &mut HandleContainer< + ::FusionHandle, + >, + ) { + let input = handles.get_float_tensor::(&self.desc.input); + let output = BOutput::cast_float(input, BInput::FloatElem::dtype()); + + handles.register_handle(self.desc.out.id, output); + } + } + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), BOutput::FloatElem::dtype()); + + let desc = CastOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + out.client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())), + Cast::::new(desc), + ); + + out.client.clone().to_backend::(out) } } diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 1587bcde48..20ae109e46 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -13,6 +13,7 @@ use burn_tensor::{ pub trait FusionClient: Send + Sync + Clone { /// The [fusion backend](FusionBackend) 
associated type. type FusionBackend: FusionBackend; + type Client: FusionClient; /// Create a new client for the given [device](Backend::Device). fn new(device: Device) -> Self; @@ -75,4 +76,10 @@ pub trait FusionClient: Send + Sync + Clone { ) -> FusionTensor; /// Drop the tensor with the given [tensor id](TensorId). fn register_orphan(&self, id: &TensorId); + fn to_backend(&self, tensor: FusionTensor) -> FusionTensor> + where + B: FusionBackend< + FusionRuntime = ::FusionRuntime, + Device = ::Device, + >; } diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index ae360fb73d..88cb4109bc 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -38,6 +38,7 @@ where B: FusionBackend, { type FusionBackend = B; + type Client = MutexFusionClient; fn new(device: B::Device) -> Self { Self { @@ -169,4 +170,26 @@ where fn register_orphan(&self, id: &TensorId) { self.server.lock().drop_tensor_handle(*id); } + + fn to_backend(&self, tensor: FusionTensor) -> FusionTensor> + where + B1: FusionBackend< + FusionRuntime = ::FusionRuntime, + Device = ::Device, + >, + { + let client = MutexFusionClient { + server: self.server.clone(), + device: self.device.clone(), + }; + + FusionTensor { + id: tensor.id.clone(), + shape: tensor.shape.clone(), + client, + dtype: tensor.dtype, + is_orphan: tensor.is_orphan, + stream: tensor.stream, + } + } } diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index 1d26532f5c..1681d6a811 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -50,7 +50,7 @@ where id: StreamId, ) -> burn_tensor::Reader, D>> where - B: FusionBackend, + B: FusionBackend, { // Make sure all registered operations are executed. // The underlying backend can still be async. @@ -66,7 +66,7 @@ where id: StreamId, ) -> burn_tensor::Reader, D>> where - B: FusionBackend, + B: FusionBackend, { // Make sure all registered operations are executed. // The underlying backend can still be async. @@ -82,7 +82,7 @@ where id: StreamId, ) -> burn_tensor::Reader> where - B: FusionBackend, + B: FusionBackend, { // Make sure all registered operations are executed. // The underlying backend can still be async. 
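// --- Editor's note (illustrative aside, not part of the patch) ---------------
// The comments repeated in the read hunks above all describe one contract: a
// read first drains every operation still queued on the stream, then defers to
// the backend, which may itself remain async. The method bodies are elided in
// the diff, so treat this reconstruction of the first step as an assumption:
fn drain_before_read<R: FusionRuntime>(server: &mut FusionServer<R>, stream: StreamId) {
    // Flush all registered-but-unexecuted operations on this stream so the
    // backend read that follows observes them in order.
    server.drain_stream(stream);
}
// ------------------------------------------------------------------------------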
@@ -99,7 +99,7 @@ where server_device: &mut Self, ) -> Arc where - B: FusionBackend, + B: FusionBackend, { let tensor = self.handles.get_float_tensor::(tensor); let tensor = B::float_to_device(tensor, device); @@ -119,7 +119,7 @@ where server_device: &mut Self, ) -> Arc where - B: FusionBackend, + B: FusionBackend, { let tensor = self.handles.get_int_tensor::(tensor); let tensor = B::int_to_device(tensor, device); @@ -139,7 +139,7 @@ where server_device: &mut Self, ) -> Arc where - B: FusionBackend, + B: FusionBackend, { let tensor = self.handles.get_bool_tensor::(tensor); let tensor = B::bool_to_device(tensor, device); diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 422084a782..d384289706 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -849,6 +849,12 @@ impl RelativeOps for BaseOperationDescription { out: desc.out.to_relative(converter), }) } + BaseOperationDescription::Cast(desc) => { + BaseOperationDescription::Cast(CastOperationDescription { + input: desc.input.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } } } } diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 61dd1e0d20..944a783660 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,11 +1,11 @@ use super::{ElementWise, ElementWiseState}; use crate::{ - element::JitElement, fusion::ElementWiseBuilder, tensor::JitTensor, FloatElement, IntElement, - JitBackend, Runtime, + element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement, + IntElement, JitBackend, PrecisionBridge, Runtime, }; use burn_compute::client::ComputeClient; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; -use burn_tensor::{repr::ReprBackend, Shape}; +use burn_tensor::{backend::BackendBridge, repr::ReprBackend, Shape}; use core::marker::PhantomData; use serde::{Deserialize, Serialize}; @@ -118,8 +118,22 @@ pub struct FusionJitRuntime { } impl FusionBackend for JitBackend { - type FusionClient = MutexFusionClient; type FusionRuntime = FusionJitRuntime; + + type FullPrecisionBackend = JitBackend; + + fn cast_float( + tensor: burn_tensor::ops::FloatTensor, + dtype: burn_tensor::DType, + ) -> Self::Handle { + match dtype { + burn_tensor::DType::F32 => { + let tensor = kernel::cast::(tensor); + JitFusionHandle::from(tensor) + } + _ => panic!("Unsupported"), + } + } } pub fn strides_dyn_rank(shape: &[usize]) -> Vec { diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 220d8a90cf..12a8330174 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -195,6 +195,8 @@ pub enum BaseOperationDescription { /// Int => [cat](crate::ops::IntTensorOps::int_cat). /// Bool => [cat](crate::ops::BoolTensorOps::bool_cat). Cat(CatOperationDescription), + /// Cast operation, no direct operation and should be supported by fusion backend. + Cast(CastOperationDescription), } /// Numeric operations on int and float tensors. 
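// --- Editor's note (illustrative aside, not part of the patch) ---------------
// `Cast` needs no element parameter because both of its tensor descriptions
// carry a dtype. The arm added below reports exactly the two tensors a cast
// touches to the stream analysis; the same fact as a standalone helper (name
// is ours):
fn cast_tensors(desc: &CastOperationDescription) -> Vec<&TensorDescription> {
    // A cast reads `input` and writes `out`; nothing else is involved.
    vec![&desc.input, &desc.out]
}
// ------------------------------------------------------------------------------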
@@ -1005,6 +1007,13 @@ pub struct InterpolateDescription { pub out: TensorDescription, } +#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] +#[allow(missing_docs)] +pub struct CastOperationDescription { + pub input: TensorDescription, + pub out: TensorDescription, +} + impl From for InterpolateMode { fn from(val: InterpolateModeDescription) -> Self { match val { @@ -1102,6 +1111,7 @@ impl BaseOperationDescription { vec![&desc.tensor, &desc.out] } BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(), + BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out], } } } From da5cfa9678d8aaeeb14178cee9a4e31420da0792 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 1 May 2024 07:30:09 -0400 Subject: [PATCH 08/25] Fusion Client --- crates/burn-fusion/src/backend.rs | 24 ++-- crates/burn-fusion/src/bridge.rs | 131 +++++++----------- crates/burn-fusion/src/client/base.rs | 60 ++++---- crates/burn-fusion/src/client/mutex.rs | 95 ++++++------- crates/burn-fusion/src/fusion.rs | 9 +- crates/burn-fusion/src/ops/boolean.rs | 4 +- crates/burn-fusion/src/ops/float.rs | 4 +- crates/burn-fusion/src/ops/int.rs | 10 +- crates/burn-fusion/src/stream/context.rs | 2 +- crates/burn-fusion/src/tensor.rs | 29 ++-- crates/burn-jit/src/fusion/base.rs | 5 +- .../burn-jit/src/fusion/elemwise/builder.rs | 3 + crates/burn-tensor/src/repr/operation.rs | 9 +- 13 files changed, 178 insertions(+), 207 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index f70c7aa47f..ea3833e7ff 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -1,7 +1,5 @@ use crate::{ - client::{FusionClient, MutexFusionClient}, - stream::Context, - FusionClientLocator, FusionTensor, PrecisionBridge, + client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge, }; use burn_tensor::{ backend::{Backend, DeviceOps}, @@ -14,7 +12,7 @@ use std::marker::PhantomData; pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new(); -pub(crate) fn get_client(device: &Device) -> MutexFusionClient { +pub(crate) fn get_client(device: &Device) -> Client { CLIENTS.client(device) } @@ -32,15 +30,15 @@ where type FullPrecisionBridge = PrecisionBridge; - type FloatTensorPrimitive = FusionTensor>; + type FloatTensorPrimitive = FusionTensor>; type FloatElem = B::FloatElem; - type IntTensorPrimitive = FusionTensor>; + type IntTensorPrimitive = FusionTensor>; type IntElem = B::IntElem; - type BoolTensorPrimitive = FusionTensor>; + type BoolTensorPrimitive = FusionTensor>; fn name() -> String { format!("fusion<{}>", B::name()) @@ -51,7 +49,7 @@ where } fn sync(device: &Self::Device) { - let client = CLIENTS.client::>(&device.clone()); + let client = CLIENTS.client::>(&device.clone()); client.drain(); B::sync(device) } @@ -126,6 +124,13 @@ pub trait Optimization: Send { fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self; } +/// Type alias for `::FusionDevice`. +pub type FusionDevice = ::FusionDevice; +/// Type alias for `::FusionHandle`. +pub type FusionHandle = ::FusionHandle; +/// Type alias for `::FusionClient`. +pub type Client = ::FusionClient; + /// Trait that allows an existing [backend](Backend) to specify graph optimizations using /// [operation builder](crate::OptimizationBuilder). pub trait FusionRuntime: Send + Sync @@ -140,6 +145,8 @@ where type FusionHandle: Clone + Send; /// Device type FusionDevice: DeviceOps; + /// The client to be used. 
+ type FusionClient: FusionClient; /// The list of optimizations that will be used to optimize the computational graph. fn optimizations( @@ -158,6 +165,7 @@ pub trait FusionBackend: /// The runtime. type FusionRuntime: FusionRuntime; + /// Cast a float tensor. fn cast_float( tensor: FloatTensor, dtype: burn_tensor::DType, diff --git a/crates/burn-fusion/src/bridge.rs b/crates/burn-fusion/src/bridge.rs index 8c7467714c..ae906fd410 100644 --- a/crates/burn-fusion/src/bridge.rs +++ b/crates/burn-fusion/src/bridge.rs @@ -5,7 +5,7 @@ use burn_tensor::{ backend::BackendBridge, ops::FloatTensor, repr::{ - BaseOperationDescription, CastOperationDescription, HandleContainer, OperationDescription, + BaseOperationDescription, HandleContainer, OperationDescription, UnaryOperationDescription, }, Element, }; @@ -17,106 +17,73 @@ pub struct PrecisionBridge { _b: PhantomData, } -impl BackendBridge> for PrecisionBridge +impl BackendBridge> for PrecisionBridge where BInput: FusionBackend, - BOutput: FusionBackend, + BTarget: FusionBackend, { - type Target = Fusion; + type Target = Fusion; fn into_target( tensor: FloatTensor, D>, _device: Option>, ) -> FloatTensor { - #[derive(new)] - struct Cast { - desc: CastOperationDescription, - _bi: PhantomData, - _bt: PhantomData, - } - - impl Operation for Cast - where - BInput: FusionBackend, - BOutput: FusionBackend, - { - fn execute( - self: Box, - handles: &mut HandleContainer< - ::FusionHandle, - >, - ) { - let input = handles.get_float_tensor::(&self.desc.input); - let output = BInput::cast_float(input, BOutput::FloatElem::dtype()); - - handles.register_handle(self.desc.out.id, output); - } - } - - let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(tensor.shape.clone(), BOutput::FloatElem::dtype()); - - let desc = CastOperationDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }; - - out.client.register( - vec![stream], - OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())), - Cast::::new(desc), - ); - - out.client.clone().to_backend::(out) + cast::(tensor) } fn from_target( tensor: FloatTensor, _device: Option>>, ) -> FloatTensor, D> { - #[derive(new)] - struct Cast { - desc: CastOperationDescription, - _bi: PhantomData, - _bt: PhantomData, - } + cast::(tensor) + } +} - impl Operation for Cast - where - BInput: FusionBackend, - BOutput: FusionBackend, - { - fn execute( - self: Box, - handles: &mut HandleContainer< - ::FusionHandle, - >, - ) { - let input = handles.get_float_tensor::(&self.desc.input); - let output = BOutput::cast_float(input, BInput::FloatElem::dtype()); +fn cast( + input: FloatTensor, D>, +) -> FloatTensor, D> +where + BInput: FusionBackend, + BTarget: FusionBackend, +{ + #[derive(new)] + struct Cast { + desc: UnaryOperationDescription, + _bi: PhantomData, + _bt: PhantomData, + } - handles.register_handle(self.desc.out.id, output); - } + impl Operation for Cast + where + BInput: FusionBackend, + BTarget: FusionBackend, + { + fn execute( + self: Box, + handles: &mut HandleContainer<::FusionHandle>, + ) { + let input = handles.get_float_tensor::(&self.desc.input); + let output = BInput::cast_float(input, BTarget::FloatElem::dtype()); + + handles.register_handle(self.desc.out.id, output); } + } - let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(tensor.shape.clone(), BOutput::FloatElem::dtype()); + let stream = input.stream; + let out = input + .client + .tensor_uninitialized(input.shape.clone(), BTarget::FloatElem::dtype()); 
- let desc = CastOperationDescription { - input: tensor.into_description(), - out: out.to_description_out(), - }; + let desc = UnaryOperationDescription { + input: input.into_description(), + out: out.to_description_out(), + }; - out.client.register( - vec![stream], - OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())), - Cast::::new(desc), - ); + out.client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())), + Cast::::new(desc), + ); - out.client.clone().to_backend::(out) - } + out } diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 20ae109e46..1bdcec8bad 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -1,85 +1,89 @@ use crate::{ stream::{execution::Operation, StreamId}, - FusionBackend, FusionTensor, Handle, + FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, }; use burn_tensor::{ - backend::Backend, ops::{FloatElem, IntElem}, repr::{OperationDescription, TensorDescription, TensorId}, - DType, Data, Device, Reader, + DType, Data, Reader, }; /// Define how to interact with the fusion server. pub trait FusionClient: Send + Sync + Clone { - /// The [fusion backend](FusionBackend) associated type. - type FusionBackend: FusionBackend; - type Client: FusionClient; + /// The [fusion runtime](FusionRuntime) associated type. + type FusionRuntime: FusionRuntime; /// Create a new client for the given [device](Backend::Device). - fn new(device: Device) -> Self; + fn new(device: FusionDevice) -> Self; /// Register a new [tensor operation description](OperationDescription). fn register(&self, streams: Vec, description: OperationDescription, operation: O) where - O: Operation<::FusionRuntime> + 'static; + O: Operation + 'static; /// Register all lazy computation. fn drain(&self); /// Get the current device used by all operations handled by this client. - fn device(&self) -> &::Device; + fn device(&self) -> &FusionDevice; /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. fn tensor_uninitialized(&self, shape: Vec, dtype: DType) -> FusionTensor; /// Create a tensor with the given handle and shape. fn register_tensor( &self, - handle: Handle, + handle: FusionHandle, shape: Vec, stream: StreamId, dtype: DType, ) -> FusionTensor; /// Read the values contained by a float tensor. - fn read_tensor_float( + fn read_tensor_float( &self, tensor: TensorDescription, stream: StreamId, - ) -> Reader, D>>; + ) -> Reader, D>> + where + B: FusionBackend; /// Read the values contained by an int tensor. - fn read_tensor_int( + fn read_tensor_int( &self, tensor: TensorDescription, stream: StreamId, - ) -> Reader, D>>; + ) -> Reader, D>> + where + B: FusionBackend; /// Read the values contained by a bool tensor. - fn read_tensor_bool( + fn read_tensor_bool( &self, tensor: TensorDescription, stream: StreamId, - ) -> Reader>; + ) -> Reader> + where + B: FusionBackend; /// Change the client of the given float tensor. - fn change_client_float( + fn change_client_float( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor; + ) -> FusionTensor + where + B: FusionBackend; /// Change the client of the given int tensor. - fn change_client_int( + fn change_client_int( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor; + ) -> FusionTensor + where + B: FusionBackend; /// Change the client of the given bool tensor. 
- fn change_client_bool( + fn change_client_bool( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor; + ) -> FusionTensor + where + B: FusionBackend; /// Drop the tensor with the given [tensor id](TensorId). fn register_orphan(&self, id: &TensorId); - fn to_backend(&self, tensor: FusionTensor) -> FusionTensor> - where - B: FusionBackend< - FusionRuntime = ::FusionRuntime, - Device = ::Device, - >; } diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 88cb4109bc..b04d5776ef 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -1,10 +1,9 @@ use super::FusionClient; use crate::{ stream::{execution::Operation, StreamId}, - FusionBackend, FusionServer, FusionTensor, Handle, + FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor, }; use burn_tensor::{ - backend::Backend, ops::FloatElem, repr::{OperationDescription, TensorDescription, TensorId}, DType, @@ -13,17 +12,14 @@ use spin::Mutex; use std::sync::Arc; /// Use a mutex to communicate with the fusion server. -pub struct MutexFusionClient -where - B: FusionBackend, -{ - server: Arc>>, - device: B::Device, +pub struct MutexFusionClient { + server: Arc>>, + device: FusionDevice, } -impl Clone for MutexFusionClient +impl Clone for MutexFusionClient where - B: FusionBackend, + R: FusionRuntime, { fn clone(&self) -> Self { Self { @@ -33,14 +29,13 @@ where } } -impl FusionClient for MutexFusionClient +impl FusionClient for MutexFusionClient where - B: FusionBackend, + R: FusionRuntime, { - type FusionBackend = B; - type Client = MutexFusionClient; + type FusionRuntime = R; - fn new(device: B::Device) -> Self { + fn new(device: FusionDevice) -> Self { Self { device: device.clone(), server: Arc::new(Mutex::new(FusionServer::new(device))), @@ -49,7 +44,7 @@ where fn register(&self, streams: Vec, description: OperationDescription, operation: O) where - O: Operation<::FusionRuntime> + 'static, + O: Operation + 'static, { self.server .lock() @@ -67,12 +62,13 @@ where FusionTensor::new(id, shape, dtype, self.clone(), StreamId::current()) } - fn device(&self) -> &::Device { + fn device(&self) -> &FusionDevice { &self.device } + fn register_tensor( &self, - handle: Handle, + handle: FusionHandle, shape: Vec, stream: StreamId, dtype: DType, @@ -85,37 +81,48 @@ where FusionTensor::new(id, shape, dtype, self.clone(), stream) } - fn read_tensor_float( + fn read_tensor_float( &self, tensor: TensorDescription, stream: StreamId, - ) -> burn_tensor::Reader, D>> { + ) -> burn_tensor::Reader, D>> + where + B: FusionBackend, + { self.server.lock().read_float::(tensor, stream) } - fn read_tensor_int( + fn read_tensor_int( &self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader, D>> + ) -> burn_tensor::Reader, D>> + where + B: FusionBackend, { self.server.lock().read_int::(tensor, id) } - fn read_tensor_bool( + fn read_tensor_bool( &self, tensor: TensorDescription, stream: StreamId, - ) -> burn_tensor::Reader> { + ) -> burn_tensor::Reader> + where + B: FusionBackend, + { self.server.lock().read_bool::(tensor, stream) } - fn change_client_float( + fn change_client_float( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor { + ) -> FusionTensor + where + B: FusionBackend, + { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); server_current.drain_stream(stream); @@ -129,12 +136,15 @@ where FusionTensor::new(id, tensor.shape, 
tensor.dtype, client, StreamId::current()) } - fn change_client_int( + fn change_client_int( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor { + ) -> FusionTensor + where + B: FusionBackend, + { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); server_current.drain_stream(stream); @@ -148,12 +158,15 @@ where FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } - fn change_client_bool( + fn change_client_bool( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor { + ) -> FusionTensor + where + B: FusionBackend, + { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); server_current.drain_stream(stream); @@ -170,26 +183,4 @@ where fn register_orphan(&self, id: &TensorId) { self.server.lock().drop_tensor_handle(*id); } - - fn to_backend(&self, tensor: FusionTensor) -> FusionTensor> - where - B1: FusionBackend< - FusionRuntime = ::FusionRuntime, - Device = ::Device, - >, - { - let client = MutexFusionClient { - server: self.server.clone(), - device: self.device.clone(), - }; - - FusionTensor { - id: tensor.id.clone(), - shape: tensor.shape.clone(), - client, - dtype: tensor.dtype, - is_orphan: tensor.is_orphan, - stream: tensor.stream, - } - } } diff --git a/crates/burn-fusion/src/fusion.rs b/crates/burn-fusion/src/fusion.rs index f224fdaa33..ece74dd656 100644 --- a/crates/burn-fusion/src/fusion.rs +++ b/crates/burn-fusion/src/fusion.rs @@ -1,9 +1,9 @@ use burn_tensor::{ - backend::{Backend, DeviceId, DeviceOps}, + backend::{DeviceId, DeviceOps}, repr::ReprBackend, }; -use crate::client::FusionClient; +use crate::{client::FusionClient, FusionDevice}; use std::{any::Any, collections::HashMap, ops::DerefMut}; @@ -26,10 +26,7 @@ impl FusionClientLocator { /// Get the fusion client for the given device. /// /// Provide the init function to create a new client if it isn't already initialized. 
- pub fn client( - &self, - device: &::Device, - ) -> C { + pub fn client(&self, device: &FusionDevice) -> C { let device_id = device.id(); let client_id = (core::any::TypeId::of::(), device_id); let mut clients = self.clients.lock(); diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index 6eba34acd6..6901974638 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -40,7 +40,7 @@ impl BoolTensorOps for Fusion { fn bool_into_data( tensor: BoolTensor, ) -> burn_tensor::Reader> { - tensor.bool_into_data() + tensor.bool_into_data::() } fn bool_from_data( @@ -149,7 +149,7 @@ impl BoolTensorOps for Fusion { let client_target = get_client::(&device_target); let client_original = tensor.client.clone(); - client_original.clone().change_client_bool::( + client_original.clone().change_client_bool::( tensor.into_description(), client_target, id, diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 1e42632def..6eb7afde51 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -172,7 +172,7 @@ impl FloatTensorOps for Fusion { fn float_into_data( tensor: FloatTensor, ) -> Reader, D>> { - tensor.into_data() + tensor.into_data::() } fn float_device(tensor: &FloatTensor) -> Device { @@ -194,7 +194,7 @@ impl FloatTensorOps for Fusion { let client_target = get_client::(&device_target); let client_original = tensor.client.clone(); - client_original.clone().change_client_float::( + client_original.clone().change_client_float::( tensor.into_description(), client_target, id, diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 42dcd2b3c9..8152119598 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -34,7 +34,7 @@ impl IntTensorOps for Fusion { } fn int_into_data(tensor: IntTensor) -> Reader, D>> { - tensor.int_into_data() + tensor.int_into_data::() } fn int_from_data( @@ -73,9 +73,11 @@ impl IntTensorOps for Fusion { let client_target = get_client::(&device_target); let client_original = tensor.client.clone(); - client_original - .clone() - .change_client_int::(tensor.into_description(), client_target, id) + client_original.clone().change_client_int::( + tensor.into_description(), + client_target, + id, + ) } fn int_reshape( diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index d384289706..a7d2d0454f 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -850,7 +850,7 @@ impl RelativeOps for BaseOperationDescription { }) } BaseOperationDescription::Cast(desc) => { - BaseOperationDescription::Cast(CastOperationDescription { + BaseOperationDescription::Cast(UnaryOperationDescription { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }) diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 5204ff0482..4ae7c55c63 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -1,6 +1,5 @@ -use crate::{client::FusionClient, stream::StreamId}; +use crate::{client::FusionClient, stream::StreamId, FusionBackend}; use burn_tensor::{ - backend::Backend, ops::{FloatElem, IntElem}, repr::{TensorDescription, TensorId, TensorStatus}, DType, Data, Reader, Shape, @@ -30,11 +29,10 @@ impl core::fmt::Debug for FusionTensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str( format!( 
- "{{ id: {:?}, shape: {:?}, should_drop: {:?}, backend: {:?}, device: {:?} }}", + "{{ id: {:?}, shape: {:?}, should_drop: {:?}, device: {:?} }}", self.id, self.shape, self.is_orphan, - ::name(), self.client.device().clone(), ) .as_str(), @@ -99,27 +97,34 @@ impl FusionTensor { } } - pub(crate) fn into_data(self) -> Reader, D>> { + pub(crate) fn into_data(self) -> Reader, D>> + where + B: FusionBackend, + { let id = self.stream; self.client .clone() - .read_tensor_float(self.into_description(), id) + .read_tensor_float::(self.into_description(), id) } - pub(crate) fn int_into_data( - self, - ) -> Reader, D>> { + pub(crate) fn int_into_data(self) -> Reader, D>> + where + B: FusionBackend, + { let id = self.stream; self.client .clone() - .read_tensor_int(self.into_description(), id) + .read_tensor_int::(self.into_description(), id) } - pub(crate) fn bool_into_data(self) -> Reader> { + pub(crate) fn bool_into_data(self) -> Reader> + where + B: FusionBackend, + { let id = self.stream; self.client .clone() - .read_tensor_bool(self.into_description(), id) + .read_tensor_bool::(self.into_description(), id) } } diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 944a783660..06f9632593 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,11 +1,11 @@ use super::{ElementWise, ElementWiseState}; use crate::{ element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement, - IntElement, JitBackend, PrecisionBridge, Runtime, + IntElement, JitBackend, Runtime, }; use burn_compute::client::ComputeClient; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; -use burn_tensor::{backend::BackendBridge, repr::ReprBackend, Shape}; +use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; use serde::{Deserialize, Serialize}; @@ -105,6 +105,7 @@ impl FusionRuntime for FusionJitRuntime { type Optimization = JitOptimization; type FusionHandle = JitFusionHandle; type FusionDevice = R::Device; + type FusionClient = MutexFusionClient; fn optimizations( device: R::Device, diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 78c3d28074..48748a1912 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -125,6 +125,9 @@ impl ElementWiseBuilder { .register_binary_ops(desc, |lhs, rhs, out| { Operator::Equal(BinaryOperator { lhs, rhs, out }) }), + BaseOperationDescription::Cast(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Assign(UnaryOperator { input, out }) + }), _ => false, } } diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 12a8330174..22811d667d 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -196,7 +196,7 @@ pub enum BaseOperationDescription { /// Bool => [cat](crate::ops::BoolTensorOps::bool_cat). Cat(CatOperationDescription), /// Cast operation, no direct operation and should be supported by fusion backend. - Cast(CastOperationDescription), + Cast(UnaryOperationDescription), } /// Numeric operations on int and float tensors. 
@@ -1007,13 +1007,6 @@ pub struct InterpolateDescription { pub out: TensorDescription, } -#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] -#[allow(missing_docs)] -pub struct CastOperationDescription { - pub input: TensorDescription, - pub out: TensorDescription, -} - impl From for InterpolateMode { fn from(val: InterpolateModeDescription) -> Self { match val { From fc92754a3846bc51647e00e78677ec43ab7d1717 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 1 May 2024 07:53:00 -0400 Subject: [PATCH 09/25] Cleanup --- Cargo.lock | 8 -------- crates/burn-fusion/src/client/base.rs | 2 +- crates/burn-tensor/src/repr/handle.rs | 12 ++++++------ 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0e84add9db..b9f92d2321 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3776,14 +3776,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "refactor" -version = "0.14.0" -dependencies = [ - "burn", - "serde", -] - [[package]] name = "regex" version = "1.10.4" diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 1bdcec8bad..a8d54b9ab2 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -13,7 +13,7 @@ pub trait FusionClient: Send + Sync + Clone { /// The [fusion runtime](FusionRuntime) associated type. type FusionRuntime: FusionRuntime; - /// Create a new client for the given [device](Backend::Device). + /// Create a new client for the given [device](FusionRuntime::FusionDevice). fn new(device: FusionDevice) -> Self; /// Register a new [tensor operation description](OperationDescription). fn register(&self, streams: Vec, description: OperationDescription, operation: O) diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 2a745ec6b2..1fff72f176 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -66,7 +66,7 @@ impl HandleContainer { } } - /// Get the [float tensor](ReprBackend::FloatTensorPrimitive) corresponding to the + /// Get the [float tensor](crate::backend::Backend::FloatTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_float_tensor( &mut self, @@ -81,7 +81,7 @@ impl HandleContainer { ) } - /// Get the [int tensor](ReprBackend::IntTensorPrimitive) corresponding to the + /// Get the [int tensor](crate::backend::Backend::IntTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_int_tensor( &mut self, @@ -96,7 +96,7 @@ impl HandleContainer { ) } - /// Get the [bool tensor](ReprBackend::BoolTensorPrimitive) corresponding to the + /// Get the [bool tensor](crate::backend::Backend::BoolTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_bool_tensor( &mut self, @@ -111,7 +111,7 @@ impl HandleContainer { ) } - /// Register a new [float tensor](ReprBackend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). + /// Register a new [float tensor](crate::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_float_tensor( &mut self, id: &TensorId, @@ -123,7 +123,7 @@ impl HandleContainer { self.handles.insert(*id, Handle::Existing(handle)); } - /// Register a new [int tensor](ReprBackend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). + /// Register a new [int tensor](crate::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). 
pub fn register_int_tensor( &mut self, id: &TensorId, @@ -135,7 +135,7 @@ impl HandleContainer { self.handles.insert(*id, Handle::Existing(handle)); } - /// Register a new [bool tensor](ReprBackend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). + /// Register a new [bool tensor](crate::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_bool_tensor( &mut self, id: &TensorId, From e5b6313271940547b744d91bc25cafc32a07e817 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 1 May 2024 08:43:54 -0400 Subject: [PATCH 10/25] Cleanup --- crates/burn-fusion/src/backend.rs | 8 +++---- .../src/codegen/dialect/gpu/shader.rs | 10 ++++----- crates/burn-jit/src/element.rs | 22 +++++++++++++++++++ crates/burn-jit/src/fusion/base.rs | 16 +++++++++----- 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index ea3833e7ff..09dc010113 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -141,9 +141,9 @@ where type OptimizationState: Serialize + DeserializeOwned; /// Optimization type for the backend. type Optimization: Optimization; - /// Handle + /// Handle used to store tensors dynamically. type FusionHandle: Clone + Send; - /// Device + /// Device used by the runtime. type FusionDevice: DeviceOps; /// The client to be used. type FusionClient: FusionClient; @@ -162,10 +162,10 @@ pub trait FusionBackend: Device = ::FusionDevice, > { - /// The runtime. + /// The runtime used for this backend. type FusionRuntime: FusionRuntime; - /// Cast a float tensor. + /// Cast a float tensor and return the resulting handle. fn cast_float( tensor: FloatTensor, dtype: burn_tensor::DType, diff --git a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs index b8aafb2938..b007533e72 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs @@ -54,15 +54,15 @@ impl From for Elem { match dtype { DType::F64 => Elem::Float(FloatKind::F64), DType::F32 => Elem::Float(FloatKind::F32), - DType::F16 => todo!(), - DType::BF16 => todo!(), + DType::F16 => Elem::Float(FloatKind::F16), + DType::BF16 => Elem::Float(FloatKind::BF16), DType::I64 => Elem::Int(IntKind::I64), DType::I32 => Elem::Int(IntKind::I32), - DType::I16 => todo!(), - DType::I8 => todo!(), + DType::I16 => panic!("i16 isn't supported yet."), + DType::I8 => panic!("i8 isn't supported yet."), DType::U64 => Elem::UInt, DType::U32 => Elem::UInt, - DType::U8 => todo!(), + DType::U8 => panic!("u8 isn't supported yet."), DType::Bool => Elem::Bool, } } diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index c5aa160aa8..14a7e60472 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -92,6 +92,27 @@ impl JitElement for f32 { } } +impl JitElement for half::f16 { + fn type_name() -> &'static str { + "f16" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } + fn gpu_elem() -> gpu::Elem { + gpu::Elem::Float(gpu::FloatKind::F16) + } + fn maximum_value() -> Self { + half::f16::MAX + } + fn minimum_value() -> Self { + half::f16::MIN + } +} + impl JitElement for half::bf16 { fn type_name() -> &'static str { "bf16" } @@ -114,4 +135,5 @@ impl JitElement for half::bf16 { } impl FloatElement for f32 {} impl FloatElement for half::bf16 {} +impl
FloatElement for half::f16 {} impl IntElement for i32 {} diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 06f9632593..fc054aba9e 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -7,6 +7,7 @@ use burn_compute::client::ComputeClient; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; +use half::{bf16, f16}; use serde::{Deserialize, Serialize}; /// Fusion optimization type for JIT. @@ -127,12 +128,17 @@ impl FusionBackend for JitBackend, dtype: burn_tensor::DType, ) -> Self::Handle { + fn cast( + tensor: JitTensor, + ) -> JitFusionHandle { + JitFusionHandle::from(kernel::cast::(tensor)) + } + match dtype { - burn_tensor::DType::F32 => { - let tensor = kernel::cast::(tensor); - JitFusionHandle::from(tensor) - } - _ => panic!("Unsupported"), + burn_tensor::DType::F32 => cast::(tensor), + burn_tensor::DType::F16 => cast::(tensor), + burn_tensor::DType::BF16 => cast::(tensor), + _ => panic!("Casting error: {dtype:?} unsupported."), } } } From 2940a6eefc038c20b5d811e8769f6383fd7fd1c9 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 09:03:37 -0400 Subject: [PATCH 11/25] Try --- crates/burn-fusion/src/backend.rs | 11 ++++------- crates/burn-fusion/src/client/base.rs | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index 09dc010113..1d4d38133c 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -133,10 +133,7 @@ pub type Client = ::FusionClient; /// Trait that allows an existing [backend](Backend) to specify graph optimizations using /// [operation builder](crate::OptimizationBuilder). -pub trait FusionRuntime: Send + Sync -where - Self: Sized, -{ +pub trait FusionRuntime: Send + Sync + Sized { /// The state that can be serialized for an optimization. type OptimizationState: Serialize + DeserializeOwned; /// Optimization type for the backend. @@ -158,9 +155,9 @@ where /// [operation builder](crate::OptimizationBuilder). pub trait FusionBackend: ReprBackend< - Handle = ::FusionHandle, - Device = ::FusionDevice, -> + Handle = ::FusionHandle, + Device = ::FusionDevice, + > + Sized { /// The runtime used for this backend. type FusionRuntime: FusionRuntime; diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index a8d54b9ab2..e1f0ecf370 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -9,7 +9,7 @@ use burn_tensor::{ }; /// Define how to interact with the fusion server. -pub trait FusionClient: Send + Sync + Clone { +pub trait FusionClient: Send + Sync + Clone + Sized { /// The [fusion runtime](FusionRuntime) associated type. 
type FusionRuntime: FusionRuntime; From 1b418f172627f7642986331412f5fa7e1a2e801e Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 09:15:10 -0400 Subject: [PATCH 12/25] Clippy --- crates/burn-compute/src/compute.rs | 11 +++++++++++ crates/burn-import/src/onnx/coalesce.rs | 4 +++- crates/burn-import/src/onnx/from_onnx.rs | 14 +++++++------- crates/burn-import/src/onnx/ir.rs | 2 +- crates/burn-jit/src/fusion/elemwise/builder.rs | 2 +- 5 files changed, 23 insertions(+), 10 deletions(-) diff --git a/crates/burn-compute/src/compute.rs b/crates/burn-compute/src/compute.rs index d396d3fc7f..9a35f53841 100644 --- a/crates/burn-compute/src/compute.rs +++ b/crates/burn-compute/src/compute.rs @@ -8,6 +8,17 @@ pub struct ComputeRuntime { clients: spin::Mutex>>>, } +impl Default for ComputeRuntime +where + Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, + Server: ComputeServer, + Channel: ComputeChannel, +{ + fn default() -> Self { + Self::new() + } +} + impl ComputeRuntime where Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, diff --git a/crates/burn-import/src/onnx/coalesce.rs b/crates/burn-import/src/onnx/coalesce.rs index ddf102ca5e..c3d5d93d21 100644 --- a/crates/burn-import/src/onnx/coalesce.rs +++ b/crates/burn-import/src/onnx/coalesce.rs @@ -173,5 +173,7 @@ fn convert_and_remove_add_node(bias_node: &Node, current_node: &mut Node) { // Push the bias input and update the output name current_node.inputs.push(bias_input); - current_node.outputs[0].name = bias_node.outputs[0].name.clone(); + current_node.outputs[0] + .name + .clone_from(&bias_node.outputs[0].name); } diff --git a/crates/burn-import/src/onnx/from_onnx.rs b/crates/burn-import/src/onnx/from_onnx.rs index d9579294a5..0a38722aec 100644 --- a/crates/burn-import/src/onnx/from_onnx.rs +++ b/crates/burn-import/src/onnx/from_onnx.rs @@ -322,7 +322,7 @@ impl OnnxGraphBuilder { node.node_type, self.node_name_counter[&node.node_type] ) .to_lowercase(); - node.name = new_name.clone(); + node.name.clone_from(&new_name); } fn check_constants(&mut self, node: &mut Node, i: usize, _graph_io: &mut OnnxGraphIO) { @@ -343,7 +343,7 @@ impl OnnxGraphBuilder { ); if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { // The value comes from Identity inputs - input.value = constant.inputs[0].value.clone(); + input.value.clone_from(&constant.inputs[0].value); input.ty = constant.inputs[0].ty.clone(); } else { let arg = convert_constant_value(constant); @@ -383,7 +383,7 @@ impl OnnxGraphBuilder { if let Some(identity_idx) = self.identity_idx.get(&x.name) { let input_name = &self.nodes[*identity_idx].inputs[0].name; - x.name = input_name.clone(); + x.name.clone_from(input_name); } }); } @@ -454,7 +454,7 @@ fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) { match node.outputs[0].ty { ArgType::Tensor(ref mut tensor_type) => { if let ArgType::Tensor(arg_tensor) = &out_arg.ty { - tensor_type.shape = arg_tensor.shape.clone(); + tensor_type.shape.clone_from(&arg_tensor.shape); let inner = arg_tensor .shape .clone() @@ -497,7 +497,7 @@ fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { for node_input in node.inputs.iter_mut() { if let Some(input_name) = graph_io.get_new_name(&node_input.name) { node_input.passed = true; - node_input.name = input_name.clone(); + node_input.name.clone_from(&input_name); } else { node_input.name = "".to_string(); node_input.passed = false; @@ -507,7 +507,7 @@ fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { if node.node_type == 
NodeType::Constant || node.node_type == NodeType::Identity { let new_name = format!("{}_out{}", node.name, out_count); graph_io.insert(&node.outputs[0], &new_name); - node.outputs[0].name = new_name.clone(); + node.outputs[0].name.clone_from(&new_name); log::debug!("Found {} constant", new_name); } else { for output in node.outputs.iter_mut() { @@ -517,7 +517,7 @@ fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { graph_io.update_name(output, &new_name); - output.name = new_name.clone(); + output.name.clone_from(&new_name); out_count += 1; } } diff --git a/crates/burn-import/src/onnx/ir.rs b/crates/burn-import/src/onnx/ir.rs index eb229ea7ce..620ffed4ec 100644 --- a/crates/burn-import/src/onnx/ir.rs +++ b/crates/burn-import/src/onnx/ir.rs @@ -29,7 +29,7 @@ impl Argument { /// Copy everything except the name from the other argument pub fn copy_value(&mut self, other_arg: &Argument) { self.ty = other_arg.ty.clone(); - self.value = other_arg.value.clone(); + self.value.clone_from(&other_arg.value); } pub fn from_initializer(initializer: &TensorProto) -> Argument { diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 48748a1912..bea4cbf5fb 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -381,7 +381,7 @@ impl ElementWiseBuilder { fn output_is_compatible(&mut self, out: &TensorDescription) -> bool { if self.current_output_shape.is_empty() { - self.current_output_shape = out.shape.clone(); + self.current_output_shape.clone_from(&out.shape); } else if self.current_output_shape != out.shape { return false; } From 49cde674912a3135c31e4a602c7fa760f1aeabeb Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 09:26:40 -0400 Subject: [PATCH 13/25] Update to Rust 1.78 --- .cargo/config | 2 -- crates/burn-common/src/reader.rs | 4 ++-- crates/burn-core/src/record/serde/ser.rs | 24 ++++++++----------- crates/burn-import/src/burn/codegen.rs | 4 ++-- crates/burn-jit/src/kernel/matmul/base.rs | 9 +++++-- crates/burn-jit/src/kernel/matmul/tiling2d.rs | 5 ---- crates/burn-tensor/src/tensor/data.rs | 5 ++-- examples/custom-training-loop/src/lib.rs | 2 +- .../text-classification/src/data/tokenizer.rs | 1 + .../text-generation/src/data/tokenizer.rs | 1 + 10 files changed, 27 insertions(+), 30 deletions(-) delete mode 100644 .cargo/config diff --git a/.cargo/config b/.cargo/config deleted file mode 100644 index c8ebfa02b7..0000000000 --- a/.cargo/config +++ /dev/null @@ -1,2 +0,0 @@ -[alias] -xtask = "run --manifest-path ./xtask/Cargo.toml --" \ No newline at end of file diff --git a/crates/burn-common/src/reader.rs b/crates/burn-common/src/reader.rs index 91f4492c1a..408b44c116 100644 --- a/crates/burn-common/src/reader.rs +++ b/crates/burn-common/src/reader.rs @@ -100,11 +100,11 @@ impl Reader { } /// Map the current reader to another type.
- pub fn map O>(self, mapper: F) -> Reader + pub fn map(self, mapper: F) -> Reader where T: 'static + Send, O: 'static + Send, - F: 'static + Send, + F: FnOnce(T) -> O + 'static + Send, { #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] return Reader::Async(Box::new(MappedReader::new(self, mapper))); diff --git a/crates/burn-core/src/record/serde/ser.rs b/crates/burn-core/src/record/serde/ser.rs index a30082b9d4..9c16d09a7a 100644 --- a/crates/burn-core/src/record/serde/ser.rs +++ b/crates/burn-core/src/record/serde/ser.rs @@ -52,13 +52,13 @@ impl SerializerTrait for Serializer { Ok(self) } - fn serialize_newtype_struct( + fn serialize_newtype_struct( self, _name: &'static str, value: &T, ) -> Result where - T: Serialize, + T: Serialize + ?Sized, { value.serialize(self) } @@ -128,9 +128,9 @@ impl SerializerTrait for Serializer { unimplemented!() } - fn serialize_some(self, value: &T) -> Result + fn serialize_some(self, value: &T) -> Result where - T: Serialize, + T: Serialize + ?Sized, { value.serialize(self) } @@ -152,7 +152,7 @@ impl SerializerTrait for Serializer { unimplemented!() } - fn serialize_newtype_variant( + fn serialize_newtype_variant( self, _name: &'static str, _variant_index: u32, _variant: &'static str, _value: &T, ) -> Result where - T: Serialize, + T: Serialize + ?Sized, { unimplemented!() } @@ -207,13 +207,9 @@ impl SerializeStruct for Serializer { type Ok = NestedValue; type Error = Error; - fn serialize_field( - &mut self, - key: &'static str, - value: &T, - ) -> Result<(), Self::Error> + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> where - T: Serialize, + T: Serialize + ?Sized, { let serialized_value = value.serialize(Serializer::new())?; @@ -248,9 +244,9 @@ impl SerializeSeq for Serializer { type Ok = NestedValue; type Error = Error; - fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> where - T: Serialize, + T: Serialize + ?Sized, { let serialized_value = value.serialize(Serializer::new())?; diff --git a/crates/burn-import/src/burn/codegen.rs b/crates/burn-import/src/burn/codegen.rs index ed617e8086..734aab2d11 100644 --- a/crates/burn-import/src/burn/codegen.rs +++ b/crates/burn-import/src/burn/codegen.rs @@ -9,10 +9,10 @@ fn convert_primitive(primitive: T) -> TokenStream { value.parse().unwrap() } -fn convert_to_array<'a, I, T: ToTokens>(list: I) -> TokenStream +fn convert_to_array<'a, I, T>(list: I) -> TokenStream where I: Iterator, - T: 'a, + T: ToTokens + 'a, { let mut body = quote! {}; diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 60bb659daa..ae6cfb55aa 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -85,7 +85,6 @@ impl Default for Tiling2dConfig { } /// The strategy to be used when launching a matmul kernel. -#[derive(Default)] pub enum MatmulStrategy { /// A simple kernel will be used with memory coalescing optimization. Simple { @@ -100,11 +99,17 @@ pub enum MatmulStrategy { Tiling2dPadded(Tiling2dConfig), #[cfg(feature = "autotune")] /// Using autotune to choose the best kernel based on runtime information. - #[default] Autotune, } +#[allow(clippy::derivable_impls)] // Necessary otherwise the feature flags don't work.
#[cfg(feature = "autotune")] +impl Default for MatmulStrategy { + fn default() -> Self { + MatmulStrategy::Autotune + } +} + #[cfg(not(feature = "autotune"))] impl Default for MatmulStrategy { fn default() -> Self { diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index 3250fb54c4..2ef3b17d0a 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -20,11 +20,6 @@ use super::{ Tiling2dConfig, }; -#[derive(new, Debug)] -struct MatmulTiling2d { - _elem: PhantomData, -} - #[derive(new, Debug)] struct MatmulTiling2dEagerKernel { config: Tiling2dConfig, diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 35c0d49eeb..f1d38d049d 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -112,9 +112,10 @@ impl Distribution { /// # Returns /// /// The distribution sampler. - pub fn sampler(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> + pub fn sampler(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> where - E: rand::distributions::uniform::SampleUniform, + R: RngCore, + E: Element + rand::distributions::uniform::SampleUniform, Standard: rand::distributions::Distribution, { let kind = match self { diff --git a/examples/custom-training-loop/src/lib.rs b/examples/custom-training-loop/src/lib.rs index 89b9bd80f9..5d01dfadd1 100644 --- a/examples/custom-training-loop/src/lib.rs +++ b/examples/custom-training-loop/src/lib.rs @@ -159,7 +159,7 @@ where #[allow(dead_code)] impl Learner2 { - pub fn step3(&mut self, _batch: MnistBatch) + pub fn step3(&mut self, _batch: MnistBatch) where B: AutodiffBackend, M: AutodiffModule, diff --git a/examples/text-classification/src/data/tokenizer.rs b/examples/text-classification/src/data/tokenizer.rs index 3d1044b365..4b1c8adee3 100644 --- a/examples/text-classification/src/data/tokenizer.rs +++ b/examples/text-classification/src/data/tokenizer.rs @@ -5,6 +5,7 @@ // This trait represents the common interface for all tokenizer types. // The `Send + Sync` bounds are necessary for allowing these operations // to work across thread boundaries. +#[allow(dead_code)] pub trait Tokenizer: Send + Sync { /// Converts a text string into a sequence of tokens. 
fn encode(&self, value: &str) -> Vec; diff --git a/examples/text-generation/src/data/tokenizer.rs b/examples/text-generation/src/data/tokenizer.rs index cf6fc81bae..53b294bc3f 100644 --- a/examples/text-generation/src/data/tokenizer.rs +++ b/examples/text-generation/src/data/tokenizer.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] pub trait Tokenizer: Send + Sync { fn encode(&self, value: &str, special_tokens: bool) -> Vec; fn decode(&self, tokens: &[usize]) -> String; From 114f94bdaeade47dde90ee2466784e4b7b3cdce5 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 09:34:18 -0400 Subject: [PATCH 14/25] Cargo config --- .cargo/config | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .cargo/config diff --git a/.cargo/config b/.cargo/config new file mode 100644 index 0000000000..c8ebfa02b7 --- /dev/null +++ b/.cargo/config @@ -0,0 +1,2 @@ +[alias] +xtask = "run --manifest-path ./xtask/Cargo.toml --" \ No newline at end of file From 842d8fd031d8b8a515f88710599ebfdb2932b664 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 2 May 2024 11:42:00 -0500 Subject: [PATCH 15/25] Fix precondition violated unsafe cloning --- crates/burn-core/src/record/serde/de.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs index b429e1cd2a..5a09b3bde2 100644 --- a/crates/burn-core/src/record/serde/de.rs +++ b/crates/burn-core/src/record/serde/de.rs @@ -348,13 +348,13 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer { fn clone_unsafely(thing: &T) -> T { unsafe { // Allocate memory for the clone. - let clone = ptr::null_mut(); - // Correcting pointer usage based on feedback - let clone = ptr::addr_of_mut!(*clone); + let mut clone = std::mem::MaybeUninit::::uninit(); + // Get a mutable pointer to the allocated memory. + let clone_ptr = clone.as_mut_ptr(); // Copy the memory - ptr::copy_nonoverlapping(thing as *const T, clone, 1); - // Transmute the cloned data pointer into an owned instance of T. - ptr::read(clone) + ptr::copy_nonoverlapping(thing as *const T, clone_ptr, 1); + // Assume the cloned data is initialized and convert it to an owned instance of T. + clone.assume_init() } } From e0c15a939627755aa818cc33ba9e3c4b1ff4482f Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 14:41:51 -0400 Subject: [PATCH 16/25] Refactor FusionTensor + FusionClient --- crates/burn-fusion/src/backend.rs | 24 ++++++++---------- crates/burn-fusion/src/client/base.rs | 35 ++++++++++++-------------- crates/burn-fusion/src/client/mutex.rs | 29 +++++++++------------ crates/burn-fusion/src/fusion.rs | 18 ++++++------- crates/burn-fusion/src/tensor.rs | 34 +++++++++++++++++-------- 5 files changed, 70 insertions(+), 70 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index 1d4d38133c..5758470dd6 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -13,7 +13,7 @@ use std::marker::PhantomData; pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new(); pub(crate) fn get_client(device: &Device) -> Client { - CLIENTS.client(device) + CLIENTS.client::(device) } /// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend). 
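This commit is the pivot of the series: `Fusion` and `FusionTensor` become generic over a backend/runtime pair instead of a concrete client type. A simplified sketch of the relationship the following hunks establish (trait names from this patch, members trimmed to the structural ones):

// The runtime owns the untyped storage, device, and client types.
trait FusionRuntime: Sized {
    type FusionHandle: Clone + Send;
    type FusionDevice: Clone;
}

// Each backend binds itself to exactly one runtime and layers the
// dtype-aware tensor operations on top of its handles.
trait FusionBackend {
    type FusionRuntime: FusionRuntime;
}

Several backends can then share one runtime, which is what the precision bridge reworked at the end of the series relies on to cast tensors between float precisions without leaving the fusion stream.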
@@ -22,23 +22,20 @@ pub struct Fusion { _backend: PhantomData, } -impl Backend for Fusion -where - B: FusionBackend, -{ +impl Backend for Fusion { type Device = B::Device; type FullPrecisionBridge = PrecisionBridge; - type FloatTensorPrimitive = FusionTensor>; + type FloatTensorPrimitive = FusionTensor; type FloatElem = B::FloatElem; - type IntTensorPrimitive = FusionTensor>; + type IntTensorPrimitive = FusionTensor; type IntElem = B::IntElem; - type BoolTensorPrimitive = FusionTensor>; + type BoolTensorPrimitive = FusionTensor; fn name() -> String { format!("fusion<{}>", B::name()) @@ -49,7 +46,7 @@ where } fn sync(device: &Self::Device) { - let client = CLIENTS.client::>(&device.clone()); + let client = CLIENTS.client::(&device.clone()); client.drain(); B::sync(device) } @@ -131,8 +128,7 @@ pub type FusionHandle = ::FusionHandle; /// Type alias for `::FusionClient`. pub type Client = ::FusionClient; -/// Trait that allows an existing [backend](Backend) to specify graph optimizations using -/// [operation builder](crate::OptimizationBuilder). +/// Trait that defines a runtime that will benefit from fused operations. pub trait FusionRuntime: Send + Sync + Sized { /// The state that can be serialized for an optimization. type OptimizationState: Serialize + DeserializeOwned; @@ -143,7 +139,7 @@ pub trait FusionRuntime: Send + Sync + Sized { /// Device used by the runtime. type FusionDevice: DeviceOps; /// The client to be used. - type FusionClient: FusionClient; + type FusionClient: FusionClient; /// The list of optimizations that will be used to optimize the computational graph. fn optimizations( @@ -155,8 +151,8 @@ pub trait FusionRuntime: Send + Sync + Sized { /// [operation builder](crate::OptimizationBuilder). pub trait FusionBackend: ReprBackend< - Handle = ::FusionHandle, - Device = ::FusionDevice, -> + Handle = FusionHandle, + Device = FusionDevice, > + Sized { /// The runtime used for this backend. type FusionRuntime: FusionRuntime; diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index e1f0ecf370..6aac0d1353 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -9,30 +9,27 @@ use burn_tensor::{ }; /// Define how to interact with the fusion server. -pub trait FusionClient: Send + Sync + Clone + Sized { - /// The [fusion runtime](FusionRuntime) associated type. - type FusionRuntime: FusionRuntime; - +pub trait FusionClient: Send + Sync + Clone + Sized { /// Create a new client for the given [device](FusionRuntime::FusionDevice). - fn new(device: FusionDevice) -> Self; + fn new(device: FusionDevice) -> Self; /// Register a new [tensor operation description](OperationDescription). fn register(&self, streams: Vec, description: OperationDescription, operation: O) where - O: Operation + 'static; + O: Operation + 'static; /// Register all lazy computation. fn drain(&self); /// Get the current device used by all operations handled by this client. - fn device(&self) -> &FusionDevice; + fn device(&self) -> &FusionDevice; /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. - fn tensor_uninitialized(&self, shape: Vec, dtype: DType) -> FusionTensor; + fn tensor_uninitialized(&self, shape: Vec, dtype: DType) -> FusionTensor; /// Create a tensor with the given handle and shape. fn register_tensor( &self, - handle: FusionHandle, + handle: FusionHandle, shape: Vec, stream: StreamId, dtype: DType, - ) -> FusionTensor; + ) -> FusionTensor; /// Read the values contained by a float tensor.
fn read_tensor_float( &self, @@ -40,7 +37,7 @@ pub trait FusionClient: Send + Sync + Clone + Sized { stream: StreamId, ) -> Reader, D>> where - B: FusionBackend; + B: FusionBackend; /// Read the values contained by an int tensor. fn read_tensor_int( &self, @@ -48,7 +45,7 @@ pub trait FusionClient: Send + Sync + Clone + Sized { stream: StreamId, ) -> Reader, D>> where - B: FusionBackend; + B: FusionBackend; /// Read the values contained by a bool tensor. fn read_tensor_bool( &self, @@ -56,34 +53,34 @@ pub trait FusionClient: Send + Sync + Clone + Sized { stream: StreamId, ) -> Reader> where - B: FusionBackend; + B: FusionBackend; /// Change the client of the given float tensor. fn change_client_float( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor + ) -> FusionTensor where - B: FusionBackend; + B: FusionBackend; /// Change the client of the given int tensor. fn change_client_int( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor + ) -> FusionTensor where - B: FusionBackend; + B: FusionBackend; /// Change the client of the given bool tensor. fn change_client_bool( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor + ) -> FusionTensor where - B: FusionBackend; + B: FusionBackend; /// Drop the tensor with the given [tensor id](TensorId). fn register_orphan(&self, id: &TensorId); } diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index b04d5776ef..2e31588431 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -29,12 +29,7 @@ where } } -impl FusionClient for MutexFusionClient -where - R: FusionRuntime, -{ - type FusionRuntime = R; - +impl> FusionClient for MutexFusionClient { fn new(device: FusionDevice) -> Self { Self { device: device.clone(), @@ -56,7 +51,7 @@ where self.server.lock().drain_stream(id); } - fn tensor_uninitialized(&self, shape: Vec, dtype: DType) -> FusionTensor { + fn tensor_uninitialized(&self, shape: Vec, dtype: DType) -> FusionTensor { let id = self.server.lock().create_empty_handle(); FusionTensor::new(id, shape, dtype, self.clone(), StreamId::current()) @@ -72,7 +67,7 @@ where shape: Vec, stream: StreamId, dtype: DType, - ) -> FusionTensor { + ) -> FusionTensor { let mut server = self.server.lock(); let id = server.create_empty_handle(); server.handles.register_handle(*id.as_ref(), handle); @@ -87,7 +82,7 @@ where stream: StreamId, ) -> burn_tensor::Reader, D>> where - B: FusionBackend, + B: FusionBackend, { self.server.lock().read_float::(tensor, stream) } @@ -98,7 +93,7 @@ where id: StreamId, ) -> burn_tensor::Reader, D>> where - B: FusionBackend, + B: FusionBackend, { self.server.lock().read_int::(tensor, id) } @@ -109,7 +104,7 @@ where stream: StreamId, ) -> burn_tensor::Reader> where - B: FusionBackend, + B: FusionBackend, { self.server.lock().read_bool::(tensor, stream) } @@ -119,9 +114,9 @@ where tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor + ) -> FusionTensor where - B: FusionBackend, + B: FusionBackend, { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); @@ -141,9 +136,9 @@ where tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor + ) -> FusionTensor where - B: FusionBackend, + B: FusionBackend, { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); @@ -163,9 +158,9 @@ where tensor: TensorDescription, client: Self, stream: 
StreamId, - ) -> FusionTensor + ) -> FusionTensor where - B: FusionBackend, + B: FusionBackend, { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); diff --git a/crates/burn-fusion/src/fusion.rs b/crates/burn-fusion/src/fusion.rs index ece74dd656..a99b508d0d 100644 --- a/crates/burn-fusion/src/fusion.rs +++ b/crates/burn-fusion/src/fusion.rs @@ -3,7 +3,7 @@ use burn_tensor::{ repr::ReprBackend, }; -use crate::{client::FusionClient, FusionDevice}; +use crate::{client::FusionClient, Client, FusionDevice, FusionRuntime}; use std::{any::Any, collections::HashMap, ops::DerefMut}; @@ -26,24 +26,24 @@ impl FusionClientLocator { /// Get the fusion client for the given device. /// /// Provide the init function to create a new client if it isn't already initialized. - pub fn client(&self, device: &FusionDevice) -> C { + pub fn client(&self, device: &FusionDevice) -> Client { let device_id = device.id(); - let client_id = (core::any::TypeId::of::(), device_id); + let client_id = (core::any::TypeId::of::(), device_id); let mut clients = self.clients.lock(); if clients.is_none() { - let client = C::new(device.clone()); - Self::register_inner::(client_id, client, &mut clients); + let client = Client::::new(device.clone()); + Self::register_inner::(client_id, client, &mut clients); } match clients.deref_mut() { Some(clients) => match clients.get(&client_id) { Some(client) => { - let client: &C = client.downcast_ref().unwrap(); + let client: &Client = client.downcast_ref().unwrap(); client.clone() } None => { - let client = C::new(device.clone()); + let client = Client::::new(device.clone()); let any = Box::new(client.clone()); clients.insert(client_id, any); client @@ -53,9 +53,9 @@ impl FusionClientLocator { } } - fn register_inner( + fn register_inner( key: Key, - client: C, + client: Client, clients: &mut Option>>, ) { if clients.is_none() { diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 4ae7c55c63..54d3f12f1c 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -1,4 +1,4 @@ -use crate::{client::FusionClient, stream::StreamId, FusionBackend}; +use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime}; use burn_tensor::{ ops::{FloatElem, IntElem}, repr::{TensorDescription, TensorId, TensorStatus}, @@ -7,14 +7,13 @@ use burn_tensor::{ use std::sync::Arc; /// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind. -#[derive(Clone)] -pub struct FusionTensor { +pub struct FusionTensor { /// Tensor id. pub id: Arc, /// The shape of the tensor. pub shape: Vec, /// The [fusion client](FusionClient). - pub client: C, + pub client: Client, /// The datatype of the tensor. pub dtype: DType, // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`. 
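A note on the hand-written `Clone` in the next hunk: `#[derive(Clone)]` on a generic struct adds a `Clone` bound on every type parameter, even one that is only reached through associated types, so deriving here would wrongly require `R: Clone`. A general Rust sketch of the same pattern (not burn code):

use std::marker::PhantomData;
use std::sync::Arc;

struct Wrapper<R> {
    id: Arc<u64>,
    _runtime: PhantomData<R>,
}

impl<R> Clone for Wrapper<R> {
    // Clones the fields without demanding `R: Clone`.
    fn clone(&self) -> Self {
        Self {
            id: self.id.clone(),
            _runtime: PhantomData,
        }
    }
}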
@@ -25,7 +24,20 @@ pub struct FusionTensor { pub(crate) stream: StreamId, } -impl core::fmt::Debug for FusionTensor { +impl Clone for FusionTensor { + fn clone(&self) -> Self { + Self { + id: self.id.clone(), + shape: self.shape.clone(), + client: self.client.clone(), + dtype: self.dtype, + is_orphan: self.is_orphan, + stream: self.stream, + } + } +} + +impl core::fmt::Debug for FusionTensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str( format!( @@ -40,12 +52,12 @@ impl core::fmt::Debug for FusionTensor { } } -impl FusionTensor { +impl FusionTensor { pub(crate) fn new( id: Arc, shape: Vec, dtype: DType, - client: C, + client: Client, stream: StreamId, ) -> Self { Self { @@ -99,7 +111,7 @@ impl FusionTensor { pub(crate) fn into_data(self) -> Reader, D>> where - B: FusionBackend, + B: FusionBackend, { let id = self.stream; self.client @@ -109,7 +121,7 @@ impl FusionTensor { pub(crate) fn int_into_data(self) -> Reader, D>> where - B: FusionBackend, + B: FusionBackend, { let id = self.stream; self.client @@ -119,7 +131,7 @@ impl FusionTensor { pub(crate) fn bool_into_data(self) -> Reader> where - B: FusionBackend, + B: FusionBackend, { let id = self.stream; self.client @@ -128,7 +140,7 @@ impl FusionTensor { } } -impl Drop for FusionTensor { +impl Drop for FusionTensor { fn drop(&mut self) { if !self.is_orphan { return; From b0fa8b62797fd2717cdaf91925fa85b84af35978 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 14:44:25 -0400 Subject: [PATCH 17/25] Fix element dtype --- crates/burn-tensor/src/tensor/element.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-tensor/src/tensor/element.rs b/crates/burn-tensor/src/tensor/element.rs index e64736d083..02b8f696d5 100644 --- a/crates/burn-tensor/src/tensor/element.rs +++ b/crates/burn-tensor/src/tensor/element.rs @@ -145,7 +145,7 @@ make_element!( convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &f64, b: &f64| a.total_cmp(b), - dtype DType::F32 + dtype DType::F64 ); make_element!( From ea42a9177f9974e0a98830fd3291d17016e2f323 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 15:07:41 -0400 Subject: [PATCH 18/25] Add recursion limit --- crates/burn-wgpu/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 59a246840d..a5a382c9a7 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -1,3 +1,5 @@ +#![recursion_limit="50"] + #[macro_use] extern crate derive_new; From baf327439302d18667023b592c8f40104b62398b Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 15:13:47 -0400 Subject: [PATCH 19/25] FMT --- crates/burn-wgpu/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index a5a382c9a7..870e6e1083 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -1,4 +1,4 @@ -#![recursion_limit="50"] +#![recursion_limit = "50"] #[macro_use] extern crate derive_new; From dcb0116e7a7d4aab84a89456b985a6f62a5b6ddc Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 15:40:02 -0400 Subject: [PATCH 20/25] TEST CI --- crates/burn-wgpu/src/lib.rs | 2 -- xtask/src/runchecks.rs | 5 +++++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 870e6e1083..59a246840d 100644 --- 
a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -1,5 +1,3 @@ -#![recursion_limit = "50"] - #[macro_use] extern crate derive_new; diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 850b07c81e..5213b12949 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -327,6 +327,11 @@ fn std_checks() { // Build & test each member in workspace let members = get_workspace_members(WorkspaceMemberType::Crate); for member in members { + // TODO: FOR TEST CI + if member.name != "burn-wgpu" { + continue; + } + if disable_wgpu && member.name == "burn-wgpu" { continue; } From 2ed6000cb655c752eb3e990b31d4b13319e1427b Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 16:08:43 -0400 Subject: [PATCH 21/25] Cleanup some types --- crates/burn-fusion/src/backend.rs | 7 ++----- crates/burn-fusion/src/bridge.rs | 2 +- crates/burn-fusion/src/client/base.rs | 5 ++++- crates/burn-fusion/src/client/mutex.rs | 5 ++++- crates/burn-jit/src/element.rs | 2 -- crates/burn-tensor/src/tensor/element.rs | 4 +--- 6 files changed, 12 insertions(+), 13 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index 5758470dd6..f13059ef68 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -18,7 +18,7 @@ pub(crate) fn get_client(device: &Device) -> Client { +pub struct Fusion { _backend: PhantomData, } @@ -150,10 +150,7 @@ pub trait FusionRuntime: Send + Sync + Sized { /// Trait that allows an existing [backend](Backend) to specify graph optimizations using /// [operation builder](crate::OptimizationBuilder). pub trait FusionBackend: - ReprBackend< - Handle = FusionHandle, - Device = FusionDevice, - > + Sized + ReprBackend, Device = FusionDevice> { /// The runtime used for this backend. type FusionRuntime: FusionRuntime; diff --git a/crates/burn-fusion/src/bridge.rs b/crates/burn-fusion/src/bridge.rs index ae906fd410..490977899c 100644 --- a/crates/burn-fusion/src/bridge.rs +++ b/crates/burn-fusion/src/bridge.rs @@ -20,7 +20,7 @@ pub struct PrecisionBridge { impl BackendBridge> for PrecisionBridge where BInput: FusionBackend, - BTarget: FusionBackend, + BTarget: FusionBackend, { type Target = Fusion; diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 6aac0d1353..9d5f945189 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -9,7 +9,10 @@ use burn_tensor::{ }; /// Define how to interact with the fusion server. -pub trait FusionClient: Send + Sync + Clone + Sized { +pub trait FusionClient: Send + Sync + Clone + Sized +where + R: FusionRuntime, +{ /// Create a new client for the given [device](FusionRuntime::FusionDevice). fn new(device: FusionDevice) -> Self; /// Register a new [tensor operation description](OperationDescription). 
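The type cleanups in this commit are equivalences rather than behavior changes: a `Sized` supertrait imposes exactly the same constraint as a `where Self: Sized` clause, and a bound on a trait's generic parameter can move into a `where` clause without changing its meaning. A minimal sketch of the first equivalence (generic Rust, not burn code):

trait Verbose
where
    Self: Sized,
{
}

// Identical constraint, one line instead of four.
trait Terse: Sized {}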
diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 2e31588431..03159642da 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -29,7 +29,10 @@ where } } -impl> FusionClient for MutexFusionClient { +impl FusionClient for MutexFusionClient +where + R: FusionRuntime, +{ fn new(device: FusionDevice) -> Self { Self { device: device.clone(), diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index 14a7e60472..693e3db2d5 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -4,8 +4,6 @@ use burn_tensor::Element; /// The base element trait for the jit backend. pub trait JitElement: burn_tensor::Element + core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod -where - Self: Sized, { /// TODO: Remove when all wgsl static kernels are migrated. fn type_name() -> &'static str; diff --git a/crates/burn-tensor/src/tensor/element.rs b/crates/burn-tensor/src/tensor/element.rs index 02b8f696d5..7a2c99a9bc 100644 --- a/crates/burn-tensor/src/tensor/element.rs +++ b/crates/burn-tensor/src/tensor/element.rs @@ -56,9 +56,7 @@ pub trait ElementRandom { /// # Returns /// /// The random value. - fn random(distribution: Distribution, rng: &mut R) -> Self - where - Self: Sized; + fn random(distribution: Distribution, rng: &mut R) -> Self; } /// Element ordering trait. From f8e25ec1047518ef6481e4f0b385c2a37d68f8e0 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 16:29:40 -0400 Subject: [PATCH 22/25] Hardcode the client --- crates/burn-fusion/src/backend.rs | 8 ++++---- crates/burn-fusion/src/client/base.rs | 2 +- crates/burn-fusion/src/client/mutex.rs | 2 +- crates/burn-fusion/src/ops/binary.rs | 2 +- crates/burn-jit/src/fusion/base.rs | 3 +-- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index f13059ef68..b62fa74d00 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -1,5 +1,7 @@ use crate::{ - client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge, + client::{FusionClient, MutexFusionClient}, + stream::Context, + FusionClientLocator, FusionTensor, PrecisionBridge, }; use burn_tensor::{ backend::{Backend, DeviceOps}, @@ -126,7 +128,7 @@ pub type FusionDevice = ::FusionDevice; /// Type alias for `::FusionHandle`. pub type FusionHandle = ::FusionHandle; /// Type alias for `::FusionClient`. -pub type Client = ::FusionClient; +pub type Client = MutexFusionClient; /// Trait that defines a runtime that will benefits from fused operations. pub trait FusionRuntime: Send + Sync + Sized { @@ -138,8 +140,6 @@ pub trait FusionRuntime: Send + Sync + Sized { type FusionHandle: Clone + Send; /// Device used by the runtime. type FusionDevice: DeviceOps; - /// The client to be used. - type FusionClient: FusionClient; /// The list of optimizations that will be used to optimize the computational graph. fn optimizations( diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 9d5f945189..c9674c8d14 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -11,7 +11,7 @@ use burn_tensor::{ /// Define how to interact with the fusion server. pub trait FusionClient: Send + Sync + Clone + Sized where - R: FusionRuntime, + R: FusionRuntime, { /// Create a new client for the given [device](FusionRuntime::FusionDevice). 
fn new(device: FusionDevice) -> Self; diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 03159642da..5b73e69e96 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -31,7 +31,7 @@ where impl FusionClient for MutexFusionClient where - R: FusionRuntime, + R: FusionRuntime, { fn new(device: FusionDevice) -> Self { Self { diff --git a/crates/burn-fusion/src/ops/binary.rs b/crates/burn-fusion/src/ops/binary.rs index caa6093b61..a7ed09ab28 100644 --- a/crates/burn-fusion/src/ops/binary.rs +++ b/crates/burn-fusion/src/ops/binary.rs @@ -12,7 +12,7 @@ macro_rules! binary_float_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer<::Handle>) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let rhs = handles.get_float_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index fc054aba9e..443755a9e6 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -4,7 +4,7 @@ use crate::{ IntElement, JitBackend, Runtime, }; use burn_compute::client::ComputeClient; -use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; +use burn_fusion::{FusionBackend, FusionRuntime}; use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; use half::{bf16, f16}; @@ -106,7 +106,6 @@ impl FusionRuntime for FusionJitRuntime { type Optimization = JitOptimization; type FusionHandle = JitFusionHandle; type FusionDevice = R::Device; - type FusionClient = MutexFusionClient; fn optimizations( device: R::Device, From ef3215e5cb51068a23a567fb1343b11c6c3c4f59 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 16:39:59 -0400 Subject: [PATCH 23/25] Put backtrace --- xtask/src/runchecks.rs | 2 +- xtask/src/utils/cargo.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 5213b12949..2bc5d2a96d 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -344,7 +344,7 @@ fn std_checks() { } group!("Checks: {}", member.name); - cargo_build(Params::from(["-p", &member.name])); + // cargo_build(Params::from(["-p", &member.name])); cargo_test(Params::from(["-p", &member.name])); endgroup!(); } diff --git a/xtask/src/utils/cargo.rs b/xtask/src/utils/cargo.rs index 0480e6f1c6..aa835226e5 100644 --- a/xtask/src/utils/cargo.rs +++ b/xtask/src/utils/cargo.rs @@ -25,6 +25,7 @@ pub(crate) fn run_cargo_with_path>( let mut cargo = Command::new("cargo"); cargo .env("CARGO_INCREMENTAL", "0") + .env("RUST_BACKTRACE", "full") .envs(&envs) .arg(command) .args(¶ms.params) From 10d4800a220107b918ada8443b1e9c9b6ccdd65a Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 16:49:12 -0400 Subject: [PATCH 24/25] Debug --- xtask/src/runchecks.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 2bc5d2a96d..baae0412dc 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -108,7 +108,7 @@ fn cargo_test(params: Params) { // Run cargo test run_cargo( "test", - params + "--color=always" + "--" + "--color=always", + params + "--color=always" + "--verbose" + "--" + "--color=always", HashMap::new(), "Failed to run cargo test", ); @@ -305,24 +305,24 @@ fn std_checks() { let disable_wgpu = 
std::env::var("DISABLE_WGPU").is_ok(); // Check format - cargo_fmt(); + // cargo_fmt(); // Check clippy lints - cargo_clippy(); + // cargo_clippy(); // Produce documentation for each workspace member - group!("Docs: crates"); - let mut params = Params::from(["--workspace", "--no-deps"]); - // Exclude burn-cuda on all platforms - params.params.push("--exclude".to_string()); - params.params.push("burn-cuda".to_string()); - cargo_doc(params); - endgroup!(); + // group!("Docs: crates"); + // let mut params = Params::from(["--workspace", "--no-deps"]); + // // Exclude burn-cuda on all platforms + // params.params.push("--exclude".to_string()); + // params.params.push("burn-cuda".to_string()); + // cargo_doc(params); + // endgroup!(); // Setup code coverage - if is_coverage { - setup_coverage(); - } + // if is_coverage { + // setup_coverage(); + // } // Build & test each member in workspace let members = get_workspace_members(WorkspaceMemberType::Crate); From 97a2c4f05d1d12d68ba9f346dbbc33db66ee28cc Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 2 May 2024 17:30:08 -0400 Subject: [PATCH 25/25] Fix the recursive type resolution --- crates/burn-fusion/src/backend.rs | 10 ++++---- crates/burn-fusion/src/bridge.rs | 33 +++++++++++++++----------- crates/burn-fusion/src/client/base.rs | 2 +- crates/burn-fusion/src/client/mutex.rs | 2 +- crates/burn-jit/src/fusion/base.rs | 4 +++- xtask/src/runchecks.rs | 33 +++++++++++--------------- xtask/src/utils/cargo.rs | 1 - 7 files changed, 43 insertions(+), 42 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index b62fa74d00..e5b8cb1a66 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -1,7 +1,5 @@ use crate::{ - client::{FusionClient, MutexFusionClient}, - stream::Context, - FusionClientLocator, FusionTensor, PrecisionBridge, + client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge, }; use burn_tensor::{ backend::{Backend, DeviceOps}, @@ -128,10 +126,10 @@ pub type FusionDevice = ::FusionDevice; /// Type alias for `::FusionHandle`. pub type FusionHandle = ::FusionHandle; /// Type alias for `::FusionClient`. -pub type Client = MutexFusionClient; +pub type Client = ::FusionClient; /// Trait that defines a runtime that will benefits from fused operations. -pub trait FusionRuntime: Send + Sync + Sized { +pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug { /// The state that can be serialized for an optimization. type OptimizationState: Serialize + DeserializeOwned; /// Optimization type for the backend. @@ -140,6 +138,8 @@ pub trait FusionRuntime: Send + Sync + Sized { type FusionHandle: Clone + Send; /// Device used by the runtime. type FusionDevice: DeviceOps; + /// The client to interact with the runtime. + type FusionClient: FusionClient; /// The list of optimizations that will be used to optimize the computational graph. fn optimizations( diff --git a/crates/burn-fusion/src/bridge.rs b/crates/burn-fusion/src/bridge.rs index 490977899c..82888075a6 100644 --- a/crates/burn-fusion/src/bridge.rs +++ b/crates/burn-fusion/src/bridge.rs @@ -14,13 +14,14 @@ use std::marker::PhantomData; #[derive(Debug)] /// Fusion bridge. 
diff --git a/crates/burn-fusion/src/bridge.rs b/crates/burn-fusion/src/bridge.rs
index 490977899c..82888075a6 100644
--- a/crates/burn-fusion/src/bridge.rs
+++ b/crates/burn-fusion/src/bridge.rs
@@ -14,13 +14,14 @@ use std::marker::PhantomData;

 #[derive(Debug)]
 /// Fusion bridge.
 pub struct PrecisionBridge<B> {
-    _b: PhantomData<B>,
+    _backend: PhantomData<B>,
 }

-impl<BInput, BTarget> BackendBridge<Fusion<BInput>> for PrecisionBridge<BTarget>
+impl<R, BInput, BTarget> BackendBridge<Fusion<BInput>> for PrecisionBridge<BTarget>
 where
-    BInput: FusionBackend,
-    BTarget: FusionBackend,
+    BInput: FusionBackend<FusionRuntime = R>,
+    BTarget: FusionBackend<FusionRuntime = R>,
+    R: FusionRuntime + 'static,
 {
     type Target = Fusion<BTarget>;

@@ -28,35 +29,39 @@ where
         tensor: FloatTensor<Fusion<BInput>, D>,
         _device: Option<Device<Self::Target>>,
     ) -> FloatTensor<Self::Target, D> {
-        cast::<BInput, BTarget, D>(tensor)
+        cast::<R, BInput, BTarget, D>(tensor)
     }

     fn from_target(
         tensor: FloatTensor<Self::Target, D>,
         _device: Option<Device<Fusion<BInput>>>,
     ) -> FloatTensor<Fusion<BInput>, D> {
-        cast::<BTarget, BInput, D>(tensor)
+        cast::<R, BTarget, BInput, D>(tensor)
     }
 }

-fn cast<BInput, BTarget, const D: usize>(
+fn cast<R, BInput, BTarget, const D: usize>(
     input: FloatTensor<Fusion<BInput>, D>,
 ) -> FloatTensor<Fusion<BTarget>, D>
 where
-    BInput: FusionBackend,
-    BTarget: FusionBackend,
+    BInput: FusionBackend<FusionRuntime = R>,
+    BTarget: FusionBackend<FusionRuntime = R>,
+    R: FusionRuntime + 'static,
 {
     #[derive(new)]
-    struct Cast<BInput: FusionBackend, BTarget: FusionBackend, const D: usize> {
+    struct Cast<R: FusionRuntime, BInput: FusionBackend, BTarget: FusionBackend, const D: usize> {
         desc: UnaryOperationDescription,
         _bi: PhantomData<BInput>,
         _bt: PhantomData<BTarget>,
+        _runtime: PhantomData<R>,
     }

-    impl<R: FusionRuntime, BInput: FusionBackend, BTarget: FusionBackend, const D: usize> Operation<R> for Cast<BInput, BTarget, D>
+    impl<R, BInput, BTarget, const D: usize> Operation<R>
+        for Cast<R, BInput, BTarget, D>
     where
-        BInput: FusionBackend,
-        BTarget: FusionBackend,
+        BInput: FusionBackend<FusionRuntime = R>,
+        BTarget: FusionBackend<FusionRuntime = R>,
+        R: FusionRuntime,
     {
         fn execute(
             self: Box<Self>,
             handles: &mut HandleContainer<R::FusionHandle>,
@@ -82,7 +87,7 @@ where
     out.client.register(
         vec![stream],
         OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())),
-        Cast::<BInput, BTarget, D>::new(desc),
+        Cast::<R, BInput, BTarget, D>::new(desc),
     );

     out
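One detail in the bridge diff above is worth spelling out: the cast operation only mentions the runtime `R` in its trait bounds, so the struct needs a `PhantomData` field before it is even allowed to declare `R` as a parameter. A freestanding sketch of that rule, with illustrative types substituted for the real descriptors:

    use std::marker::PhantomData;

    // `R` appears in no real field, so without the marker this definition
    // would be rejected with error[E0392]: parameter `R` is never used.
    struct Cast<R, BInput, BTarget> {
        desc: String, // stands in for UnaryOperationDescription
        _bi: PhantomData<BInput>,
        _bt: PhantomData<BTarget>,
        _runtime: PhantomData<R>,
    }

    fn main() {
        let _op: Cast<u8, u16, u32> = Cast {
            desc: String::from("cast f32 -> f16"),
            _bi: PhantomData,
            _bt: PhantomData,
            _runtime: PhantomData,
        };
    }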
diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs
index c9674c8d14..9d5f945189 100644
--- a/crates/burn-fusion/src/client/base.rs
+++ b/crates/burn-fusion/src/client/base.rs
@@ -11,7 +11,7 @@ use burn_tensor::{
 /// Define how to interact with the fusion server.
 pub trait FusionClient<R>: Send + Sync + Clone + Sized
 where
-    R: FusionRuntime,
+    R: FusionRuntime<FusionClient = Self>,
 {
     /// Create a new client for the given [device](FusionRuntime::FusionDevice).
     fn new(device: FusionDevice<R>) -> Self;
diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs
index 5b73e69e96..03159642da 100644
--- a/crates/burn-fusion/src/client/mutex.rs
+++ b/crates/burn-fusion/src/client/mutex.rs
@@ -31,7 +31,7 @@ where
 impl<R> FusionClient<R> for MutexFusionClient<R>
 where
-    R: FusionRuntime,
+    R: FusionRuntime<FusionClient = Self>,
 {
     fn new(device: FusionDevice<R>) -> Self {
         Self {
diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs
index 443755a9e6..919670524e 100644
--- a/crates/burn-jit/src/fusion/base.rs
+++ b/crates/burn-jit/src/fusion/base.rs
@@ -4,7 +4,7 @@ use crate::{
     IntElement, JitBackend, Runtime,
 };
 use burn_compute::client::ComputeClient;
-use burn_fusion::{FusionBackend, FusionRuntime};
+use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
 use burn_tensor::{repr::ReprBackend, Shape};
 use core::marker::PhantomData;
 use half::{bf16, f16};
@@ -106,6 +106,7 @@ impl<R: Runtime> FusionRuntime for FusionJitRuntime<R> {
     type Optimization = JitOptimization<R>;
     type FusionHandle = JitFusionHandle<R>;
     type FusionDevice = R::Device;
+    type FusionClient = MutexFusionClient<Self>;

     fn optimizations(
         device: R::Device,
@@ -114,6 +115,7 @@ impl<R: Runtime> FusionRuntime for FusionJitRuntime<R> {
     }
 }

+#[derive(Debug)]
 pub struct FusionJitRuntime<R: Runtime> {
     _b: PhantomData<R>,
 }
diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs
index baae0412dc..850b07c81e 100644
--- a/xtask/src/runchecks.rs
+++ b/xtask/src/runchecks.rs
@@ -108,7 +108,7 @@ fn cargo_test(params: Params) {
     // Run cargo test
     run_cargo(
         "test",
-        params + "--color=always" + "--verbose" + "--" + "--color=always",
+        params + "--color=always" + "--" + "--color=always",
         HashMap::new(),
         "Failed to run cargo test",
     );
@@ -305,33 +305,28 @@ fn std_checks() {
     let disable_wgpu = std::env::var("DISABLE_WGPU").is_ok();

     // Check format
-    // cargo_fmt();
+    cargo_fmt();

     // Check clippy lints
-    // cargo_clippy();
+    cargo_clippy();

     // Produce documentation for each workspace member
-    // group!("Docs: crates");
-    // let mut params = Params::from(["--workspace", "--no-deps"]);
-    // // Exclude burn-cuda on all platforms
-    // params.params.push("--exclude".to_string());
-    // params.params.push("burn-cuda".to_string());
-    // cargo_doc(params);
-    // endgroup!();
+    group!("Docs: crates");
+    let mut params = Params::from(["--workspace", "--no-deps"]);
+    // Exclude burn-cuda on all platforms
+    params.params.push("--exclude".to_string());
+    params.params.push("burn-cuda".to_string());
+    cargo_doc(params);
+    endgroup!();

     // Setup code coverage
-    // if is_coverage {
-    //     setup_coverage();
-    // }
+    if is_coverage {
+        setup_coverage();
+    }

     // Build & test each member in workspace
     let members = get_workspace_members(WorkspaceMemberType::Crate);

     for member in members {
-        // TODO: FOR TEST CI
-        if member.name != "burn-wgpu" {
-            continue;
-        }
-
         if disable_wgpu && member.name == "burn-wgpu" {
             continue;
         }
@@ -344,7 +339,7 @@ fn std_checks() {
         }

         group!("Checks: {}", member.name);
-        // cargo_build(Params::from(["-p", &member.name]));
+        cargo_build(Params::from(["-p", &member.name]));
         cargo_test(Params::from(["-p", &member.name]));
         endgroup!();
     }
diff --git a/xtask/src/utils/cargo.rs b/xtask/src/utils/cargo.rs
index aa835226e5..0480e6f1c6 100644
--- a/xtask/src/utils/cargo.rs
+++ b/xtask/src/utils/cargo.rs
@@ -25,7 +25,6 @@ pub(crate) fn run_cargo_with_path<P: AsRef<Path>>(
     let mut cargo = Command::new("cargo");
     cargo
         .env("CARGO_INCREMENTAL", "0")
-        .env("RUST_BACKTRACE", "full")
         .envs(&envs)
         .arg(command)
         .args(&params.params)
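Taken together, the final patch restores the associated client on the runtime and reverts every debugging change from patches 23 and 24. From the outside, the resolution chain now looks roughly as follows; treat this as a sketch rather than exact burn API, since the module path of `FusionJitRuntime` is an assumption and `WgpuRuntime` stands in for any concrete compute runtime:

    // Hypothetical downstream usage after patch 25, assuming the burn-fusion
    // and burn-jit crates from this series as dependencies.
    use burn_fusion::FusionRuntime;
    use burn_jit::fusion::FusionJitRuntime; // path is an assumption

    type Rt = FusionJitRuntime<WgpuRuntime>;

    // Before patch 25, Client<R> was pinned to MutexFusionClient<R>.
    // After, it resolves through the runtime's associated type, so a
    // runtime may supply a different client implementation:
    type Cl = <Rt as FusionRuntime>::FusionClient; // = MutexFusionClient<Rt>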