diff --git a/Cargo.lock b/Cargo.lock
index 0e84add9db..b9f92d2321 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3776,14 +3776,6 @@ dependencies = [
  "thiserror",
 ]
 
-[[package]]
-name = "refactor"
-version = "0.14.0"
-dependencies = [
- "burn",
- "serde",
-]
-
 [[package]]
 name = "regex"
 version = "1.10.4"
diff --git a/crates/burn-common/src/reader.rs b/crates/burn-common/src/reader.rs
index 91f4492c1a..408b44c116 100644
--- a/crates/burn-common/src/reader.rs
+++ b/crates/burn-common/src/reader.rs
@@ -100,11 +100,11 @@ impl<T> Reader<T> {
     }
 
     /// Map the current reader to another type.
-    pub fn map<O, F: FnOnce(T) -> O>(self, mapper: F) -> Reader<O>
+    pub fn map<O, F>(self, mapper: F) -> Reader<O>
     where
         T: 'static + Send,
         O: 'static + Send,
-        F: 'static + Send,
+        F: FnOnce(T) -> O + 'static + Send,
     {
         #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
         return Reader::Async(Box::new(MappedReader::new(self, mapper)));
diff --git a/crates/burn-compute/src/compute.rs b/crates/burn-compute/src/compute.rs
index d396d3fc7f..9a35f53841 100644
--- a/crates/burn-compute/src/compute.rs
+++ b/crates/burn-compute/src/compute.rs
@@ -8,6 +8,17 @@ pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
     clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
 }
 
+impl<Device, Server, Channel> Default for ComputeRuntime<Device, Server, Channel>
+where
+    Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
+    Server: ComputeServer,
+    Channel: ComputeChannel<Server>,
+{
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
 where
     Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs
index b429e1cd2a..5a09b3bde2 100644
--- a/crates/burn-core/src/record/serde/de.rs
+++ b/crates/burn-core/src/record/serde/de.rs
@@ -348,13 +348,13 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
 fn clone_unsafely<T>(thing: &T) -> T {
     unsafe {
         // Allocate memory for the clone.
-        let clone = ptr::null_mut();
-        // Correcting pointer usage based on feedback
-        let clone = ptr::addr_of_mut!(*clone);
+        let mut clone = std::mem::MaybeUninit::<T>::uninit();
+        // Get a mutable pointer to the allocated memory.
+        let clone_ptr = clone.as_mut_ptr();
         // Copy the memory
-        ptr::copy_nonoverlapping(thing as *const T, clone, 1);
-        // Transmute the cloned data pointer into an owned instance of T.
-        ptr::read(clone)
+        ptr::copy_nonoverlapping(thing as *const T, clone_ptr, 1);
+        // Assume the cloned data is initialized and convert it to an owned instance of T.
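+        // SAFETY: `copy_nonoverlapping` above fully initialized `clone` from the bytes
+        // of `thing`, so `assume_init` is sound. Note that the bitwise copy means the
+        // original and the clone now alias any heap data `T` owns.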
+        clone.assume_init()
     }
 }
diff --git a/crates/burn-core/src/record/serde/ser.rs b/crates/burn-core/src/record/serde/ser.rs
index a30082b9d4..9c16d09a7a 100644
--- a/crates/burn-core/src/record/serde/ser.rs
+++ b/crates/burn-core/src/record/serde/ser.rs
@@ -52,13 +52,13 @@ impl SerializerTrait for Serializer {
         Ok(self)
     }
 
-    fn serialize_newtype_struct<T: ?Sized>(
+    fn serialize_newtype_struct<T>(
         self,
         _name: &'static str,
         value: &T,
     ) -> Result<Self::Ok, Self::Error>
     where
-        T: Serialize,
+        T: Serialize + ?Sized,
     {
         value.serialize(self)
     }
@@ -128,9 +128,9 @@ impl SerializerTrait for Serializer {
         unimplemented!()
     }
 
-    fn serialize_some<T: ?Sized>(self, value: &T) -> Result<Self::Ok, Self::Error>
+    fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
     where
-        T: Serialize,
+        T: Serialize + ?Sized,
     {
         value.serialize(self)
     }
@@ -152,7 +152,7 @@ impl SerializerTrait for Serializer {
         unimplemented!()
     }
 
-    fn serialize_newtype_variant<T: ?Sized>(
+    fn serialize_newtype_variant<T>(
         self,
         _name: &'static str,
         _variant_index: u32,
@@ -160,7 +160,7 @@ impl SerializerTrait for Serializer {
         _value: &T,
     ) -> Result<Self::Ok, Self::Error>
     where
-        T: Serialize,
+        T: Serialize + ?Sized,
     {
         unimplemented!()
     }
@@ -207,13 +207,9 @@ impl SerializeStruct for Serializer {
     type Ok = NestedValue;
     type Error = Error;
 
-    fn serialize_field<T: ?Sized>(
-        &mut self,
-        key: &'static str,
-        value: &T,
-    ) -> Result<(), Self::Error>
+    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
     where
-        T: Serialize,
+        T: Serialize + ?Sized,
     {
         let serialized_value = value.serialize(Serializer::new())?;
@@ -248,9 +244,9 @@ impl SerializeSeq for Serializer {
     type Ok = NestedValue;
     type Error = Error;
 
-    fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
+    fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
     where
-        T: Serialize,
+        T: Serialize + ?Sized,
     {
         let serialized_value = value.serialize(Serializer::new())?;
diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs
index 3d4167e5cc..e5b8cb1a66 100644
--- a/crates/burn-fusion/src/backend.rs
+++ b/crates/burn-fusion/src/backend.rs
@@ -2,7 +2,8 @@ use crate::{
     client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
 };
 use burn_tensor::{
-    backend::Backend,
+    backend::{Backend, DeviceOps},
+    ops::FloatTensor,
     repr::{OperationDescription, ReprBackend},
     Device,
 };
@@ -11,30 +12,30 @@ use std::marker::PhantomData;
 
 pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();
 
-pub(crate) fn get_client<B: FusionBackend>(device: &B::Device) -> B::FusionClient {
-    CLIENTS.client(device)
+pub(crate) fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
+    CLIENTS.client::<B::FusionRuntime>(device)
 }
 
 /// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).
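+/// A minimal usage sketch (`MyBackend` is a hypothetical type standing in for any
+/// backend that implements [`FusionBackend`]):
+///
+/// ```ignore
+/// use burn_fusion::Fusion;
+///
+/// // Wrapping the backend is the only change needed: every tensor operation is
+/// // recorded as an `OperationDescription` and fused before execution.
+/// type Fused = Fusion<MyBackend>;
+/// ```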
 #[derive(Clone, Debug, Default)]
-pub struct Fusion<B: FusionBackend> {
+pub struct Fusion<B> {
     _backend: PhantomData<B>,
 }
 
 impl<B: FusionBackend> Backend for Fusion<B> {
     type Device = B::Device;
 
-    type FullPrecisionBridge = PrecisionBridge;
+    type FullPrecisionBridge = PrecisionBridge<B::FullPrecisionBackend>;
 
-    type FloatTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
+    type FloatTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;
 
     type FloatElem = B::FloatElem;
 
-    type IntTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
+    type IntTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;
 
     type IntElem = B::IntElem;
 
-    type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
+    type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;
 
     fn name() -> String {
         format!("fusion<{}>", B::name())
@@ -45,10 +46,14 @@ impl<B: FusionBackend> Backend for Fusion<B> {
     }
 
     fn sync(device: &Self::Device) {
-        let client = CLIENTS.client::<B::FusionClient>(&device.clone());
+        let client = CLIENTS.client::<B::FusionRuntime>(&device.clone());
         client.drain();
         B::sync(device)
     }
+
+    fn ad_enabled() -> bool {
+        false
+    }
 }
 
 /// The status of a [builder](OptimizationBuilder).
@@ -101,9 +106,9 @@ pub trait OptimizationBuilder<O>: Send {
 }
 
 /// The operation created from the [builder](OptimizationBuilder).
-pub trait Optimization<B: FusionBackend>: Send {
+pub trait Optimization<R: FusionRuntime>: Send {
     /// Execute the operation.
-    fn execute(&mut self, context: &mut Context<'_, B>);
+    fn execute(&mut self, context: &mut Context<'_, R::FusionHandle>);
     /// The number of registered operations in this optimization.
     fn len(&self) -> usize;
     /// If the current optimization is empty.
@@ -111,22 +116,51 @@ pub trait Optimization<R: FusionRuntime>: Send {
         self.len() == 0
     }
     /// Returns the state that can be serialized.
-    fn to_state(&self) -> B::OptimizationState;
+    fn to_state(&self) -> R::OptimizationState;
     /// Create the optimization from the state.
-    fn from_state(device: &B::Device, state: B::OptimizationState) -> Self;
+    fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
 }
 
-/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
-/// [operation builder](crate::OptimizationBuilder).
-pub trait FusionBackend: Backend + ReprBackend {
+/// Type alias for `<R as FusionRuntime>::FusionDevice`.
+pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
+/// Type alias for `<R as FusionRuntime>::FusionHandle`.
+pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
+/// Type alias for `<R as FusionRuntime>::FusionClient`.
+pub type Client<R> = <R as FusionRuntime>::FusionClient;
+
+/// Trait that defines a runtime that will benefit from fused operations.
+pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug {
     /// The state that can be serialized for an optimization.
     type OptimizationState: Serialize + DeserializeOwned;
     /// Optimization type for the backend.
     type Optimization: Optimization<Self>;
-    /// What kind of client should be used.
-    type FusionClient: FusionClient<FusionBackend = Self>;
+    /// Handle used to store tensor dynamically.
+    type FusionHandle: Clone + Send;
+    /// Device used by the runtime.
+    type FusionDevice: DeviceOps;
+    /// The client to interact with the runtime.
+    type FusionClient: FusionClient<Self>;
 
     /// The list of optimizations that will be used to optimize the computational graph.
-    fn optimizations(device: Device<Self>)
-        -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
+    fn optimizations(
+        device: Self::FusionDevice,
+    ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
+}
+
+/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
+/// [operation builder](crate::OptimizationBuilder).
+pub trait FusionBackend:
+    ReprBackend<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
+{
+    /// The runtime used for this backend.
+    type FusionRuntime: FusionRuntime;
+
+    /// Cast a float tensor and return the resulting handle.
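+    /// This is what the [`PrecisionBridge`](crate::PrecisionBridge) relies on to move
+    /// tensors between two fusion backends that share a runtime but differ in float
+    /// element type.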
+    fn cast_float<const D: usize>(
+        tensor: FloatTensor<Self, D>,
+        dtype: burn_tensor::DType,
+    ) -> Self::Handle;
+
+    /// Pointer to the full precision fusion backend.
+    type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
 }
diff --git a/crates/burn-fusion/src/bridge.rs b/crates/burn-fusion/src/bridge.rs
index 375fd4fb52..82888075a6 100644
--- a/crates/burn-fusion/src/bridge.rs
+++ b/crates/burn-fusion/src/bridge.rs
@@ -1,25 +1,94 @@
-use burn_tensor::backend::BackendBridge;
-
-use crate::{Fusion, FusionBackend};
+use crate::{
+    client::FusionClient, stream::execution::Operation, Fusion, FusionBackend, FusionRuntime,
+};
+use burn_tensor::{
+    backend::BackendBridge,
+    ops::FloatTensor,
+    repr::{
+        BaseOperationDescription, HandleContainer, OperationDescription, UnaryOperationDescription,
+    },
+    Element,
+};
+use std::marker::PhantomData;
 
 #[derive(Debug)]
 /// Fusion bridge.
-pub struct PrecisionBridge;
+pub struct PrecisionBridge<B: FusionBackend> {
+    _backend: PhantomData<B>,
+}
 
-impl<B: FusionBackend> BackendBridge<Fusion<B>> for PrecisionBridge {
-    type Target = Fusion<B>;
+impl<BInput, BTarget, R> BackendBridge<Fusion<BInput>> for PrecisionBridge<BTarget>
+where
+    BInput: FusionBackend<FusionRuntime = R>,
+    BTarget: FusionBackend<FusionRuntime = R>,
+    R: FusionRuntime + 'static,
+{
+    type Target = Fusion<BTarget>;
 
     fn into_target<const D: usize>(
-        tensor: burn_tensor::ops::FloatTensor<Fusion<B>, D>,
+        tensor: FloatTensor<Fusion<BInput>, D>,
         _device: Option<Device<Self::Target>>,
-    ) -> burn_tensor::ops::FloatTensor<Self::Target, D> {
-        tensor
+    ) -> FloatTensor<Self::Target, D> {
+        cast::<R, BInput, BTarget, D>(tensor)
     }
 
     fn from_target<const D: usize>(
-        tensor: burn_tensor::ops::FloatTensor<Self::Target, D>,
-        _device: Option<Device<Fusion<B>>>,
-    ) -> burn_tensor::ops::FloatTensor<Fusion<B>, D> {
-        tensor
+        tensor: FloatTensor<Self::Target, D>,
+        _device: Option<Device<Fusion<BInput>>>,
+    ) -> FloatTensor<Fusion<BInput>, D> {
+        cast::<R, BTarget, BInput, D>(tensor)
+    }
+}
+
+fn cast<R, BInput, BTarget, const D: usize>(
+    input: FloatTensor<Fusion<BInput>, D>,
+) -> FloatTensor<Fusion<BTarget>, D>
+where
+    BInput: FusionBackend<FusionRuntime = R>,
+    BTarget: FusionBackend<FusionRuntime = R>,
+    R: FusionRuntime + 'static,
+{
+    #[derive(new)]
+    struct Cast<R: FusionRuntime, BInput: FusionBackend, BTarget: FusionBackend, const D: usize> {
+        desc: UnaryOperationDescription,
+        _bi: PhantomData<BInput>,
+        _bt: PhantomData<BTarget>,
+        _runtime: PhantomData<R>,
+    }
+
+    impl<R, BInput, BTarget, const D: usize> Operation<R> for Cast<R, BInput, BTarget, D>
+    where
+        BInput: FusionBackend<FusionRuntime = R>,
+        BTarget: FusionBackend<FusionRuntime = R>,
+        R: FusionRuntime,
+    {
+        fn execute(
+            self: Box<Self>,
+            handles: &mut HandleContainer<<R as FusionRuntime>::FusionHandle>,
+        ) {
+            let input = handles.get_float_tensor::<BInput, D>(&self.desc.input);
+            let output = BInput::cast_float(input, BTarget::FloatElem::dtype());
+
+            handles.register_handle(self.desc.out.id, output);
+        }
+    }
+
+    let stream = input.stream;
+    let out = input
+        .client
+        .tensor_uninitialized(input.shape.clone(), BTarget::FloatElem::dtype());
+
+    let desc = UnaryOperationDescription {
+        input: input.into_description(),
+        out: out.to_description_out(),
+    };
+
+    out.client.register(
+        vec![stream],
+        OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())),
+        Cast::<R, BInput, BTarget, D>::new(desc),
+    );
+
+    out
+}
diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs
index 377f98b8f0..9d5f945189 100644
--- a/crates/burn-fusion/src/client/base.rs
+++ b/crates/burn-fusion/src/client/base.rs
@@ -1,80 +1,89 @@
 use crate::{
     stream::{execution::Operation, StreamId},
-    FusionBackend, FusionTensor, Handle,
+    FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor,
 };
 use burn_tensor::{
-    backend::Backend,
     ops::{FloatElem, IntElem},
     repr::{OperationDescription, TensorDescription, TensorId},
-    Data, Device, Reader,
+    DType, Data, Reader,
 };
 
 /// Define how to interact with the fusion server.
-pub trait FusionClient: Send + Sync + Clone {
-    /// The [fusion backend](FusionBackend) associated type.
-    type FusionBackend: FusionBackend;
-
-    /// Create a new client for the given [device](Backend::Device).
- fn new(device: Device) -> Self; +pub trait FusionClient: Send + Sync + Clone + Sized +where + R: FusionRuntime, +{ + /// Create a new client for the given [device](FusionRuntime::FusionDevice). + fn new(device: FusionDevice) -> Self; /// Register a new [tensor operation description](OperationDescription). - fn register + 'static>( - &self, - streams: Vec, - description: OperationDescription, - operation: O, - ); + fn register(&self, streams: Vec, description: OperationDescription, operation: O) + where + O: Operation + 'static; /// Register all lazy computation. fn drain(&self); /// Get the current device used by all operations handled by this client. - fn device(&self) -> &::Device; + fn device(&self) -> &FusionDevice; /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. - fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor; + fn tensor_uninitialized(&self, shape: Vec, dtype: DType) -> FusionTensor; /// Create a tensor with the given handle and shape. fn register_tensor( &self, - handle: Handle, + handle: FusionHandle, shape: Vec, stream: StreamId, - ) -> FusionTensor; + dtype: DType, + ) -> FusionTensor; /// Read the values contained by a float tensor. - fn read_tensor_float( + fn read_tensor_float( &self, tensor: TensorDescription, stream: StreamId, - ) -> Reader, D>>; + ) -> Reader, D>> + where + B: FusionBackend; /// Read the values contained by an int tensor. - fn read_tensor_int( + fn read_tensor_int( &self, tensor: TensorDescription, stream: StreamId, - ) -> Reader, D>>; + ) -> Reader, D>> + where + B: FusionBackend; /// Read the values contained by a bool tensor. - fn read_tensor_bool( + fn read_tensor_bool( &self, tensor: TensorDescription, stream: StreamId, - ) -> Reader>; + ) -> Reader> + where + B: FusionBackend; /// Change the client of the given float tensor. - fn change_client_float( + fn change_client_float( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor; + ) -> FusionTensor + where + B: FusionBackend; /// Change the client of the given int tensor. - fn change_client_int( + fn change_client_int( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor; + ) -> FusionTensor + where + B: FusionBackend; /// Change the client of the given bool tensor. - fn change_client_bool( + fn change_client_bool( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor; + ) -> FusionTensor + where + B: FusionBackend; /// Drop the tensor with the given [tensor id](TensorId). fn register_orphan(&self, id: &TensorId); } diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 940d331f0f..03159642da 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -1,28 +1,25 @@ use super::FusionClient; use crate::{ stream::{execution::Operation, StreamId}, - FusionBackend, FusionServer, FusionTensor, Handle, + FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor, }; use burn_tensor::{ - backend::Backend, ops::FloatElem, repr::{OperationDescription, TensorDescription, TensorId}, + DType, }; use spin::Mutex; use std::sync::Arc; /// Use a mutex to communicate with the fusion server. 
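+/// Cloning the client is cheap: the [`FusionServer`] sits behind an
+/// `Arc<spin::Mutex<_>>`, so every clone of the client for a given device pushes
+/// work to the same queue of recorded operations.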
-pub struct MutexFusionClient -where - B: FusionBackend, -{ - server: Arc>>, - device: B::Device, +pub struct MutexFusionClient { + server: Arc>>, + device: FusionDevice, } -impl Clone for MutexFusionClient +impl Clone for MutexFusionClient where - B: FusionBackend, + R: FusionRuntime, { fn clone(&self) -> Self { Self { @@ -32,25 +29,21 @@ where } } -impl FusionClient for MutexFusionClient +impl FusionClient for MutexFusionClient where - B: FusionBackend, + R: FusionRuntime, { - type FusionBackend = B; - - fn new(device: B::Device) -> Self { + fn new(device: FusionDevice) -> Self { Self { device: device.clone(), server: Arc::new(Mutex::new(FusionServer::new(device))), } } - fn register + 'static>( - &self, - streams: Vec, - description: OperationDescription, - operation: O, - ) { + fn register(&self, streams: Vec, description: OperationDescription, operation: O) + where + O: Operation + 'static, + { self.server .lock() .register(streams, description, Box::new(operation)) @@ -61,107 +54,128 @@ where self.server.lock().drain_stream(id); } - fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor { + fn tensor_uninitialized(&self, shape: Vec, dtype: DType) -> FusionTensor { let id = self.server.lock().create_empty_handle(); - FusionTensor::new(id, shape, self.clone(), StreamId::current()) + FusionTensor::new(id, shape, dtype, self.clone(), StreamId::current()) } - fn device(&self) -> &::Device { + fn device(&self) -> &FusionDevice { &self.device } + fn register_tensor( &self, - handle: Handle, + handle: FusionHandle, shape: Vec, stream: StreamId, - ) -> FusionTensor { + dtype: DType, + ) -> FusionTensor { let mut server = self.server.lock(); let id = server.create_empty_handle(); server.handles.register_handle(*id.as_ref(), handle); core::mem::drop(server); - FusionTensor::new(id, shape, self.clone(), stream) + FusionTensor::new(id, shape, dtype, self.clone(), stream) } - fn read_tensor_float( + fn read_tensor_float( &self, tensor: TensorDescription, stream: StreamId, - ) -> burn_tensor::Reader, D>> { - self.server.lock().read_float(tensor, stream) + ) -> burn_tensor::Reader, D>> + where + B: FusionBackend, + { + self.server.lock().read_float::(tensor, stream) } - fn read_tensor_int( + fn read_tensor_int( &self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader, D>> + ) -> burn_tensor::Reader, D>> + where + B: FusionBackend, { - self.server.lock().read_int(tensor, id) + self.server.lock().read_int::(tensor, id) } - fn read_tensor_bool( + fn read_tensor_bool( &self, tensor: TensorDescription, stream: StreamId, - ) -> burn_tensor::Reader> { - self.server.lock().read_bool(tensor, stream) + ) -> burn_tensor::Reader> + where + B: FusionBackend, + { + self.server.lock().read_bool::(tensor, stream) } - fn change_client_float( + fn change_client_float( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor { + ) -> FusionTensor + where + B: FusionBackend, + { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); server_current.drain_stream(stream); let id = - server_current.change_server_float::(&tensor, &client.device, &mut server_other); + server_current.change_server_float::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); - FusionTensor::new(id, tensor.shape, client, StreamId::current()) + FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } - fn change_client_int( + fn change_client_int( &self, tensor: TensorDescription, 
client: Self, stream: StreamId, - ) -> FusionTensor { + ) -> FusionTensor + where + B: FusionBackend, + { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); server_current.drain_stream(stream); - let id = server_current.change_server_int::(&tensor, &client.device, &mut server_other); + let id = + server_current.change_server_int::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); - FusionTensor::new(id, tensor.shape, client, StreamId::current()) + FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } - fn change_client_bool( + fn change_client_bool( &self, tensor: TensorDescription, client: Self, stream: StreamId, - ) -> FusionTensor { + ) -> FusionTensor + where + B: FusionBackend, + { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); server_current.drain_stream(stream); - let id = server_current.change_server_bool::(&tensor, &client.device, &mut server_other); + let id = + server_current.change_server_bool::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); - FusionTensor::new(id, tensor.shape, client, StreamId::current()) + FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } fn register_orphan(&self, id: &TensorId) { diff --git a/crates/burn-fusion/src/fusion.rs b/crates/burn-fusion/src/fusion.rs index f224fdaa33..a99b508d0d 100644 --- a/crates/burn-fusion/src/fusion.rs +++ b/crates/burn-fusion/src/fusion.rs @@ -1,9 +1,9 @@ use burn_tensor::{ - backend::{Backend, DeviceId, DeviceOps}, + backend::{DeviceId, DeviceOps}, repr::ReprBackend, }; -use crate::client::FusionClient; +use crate::{client::FusionClient, Client, FusionDevice, FusionRuntime}; use std::{any::Any, collections::HashMap, ops::DerefMut}; @@ -26,27 +26,24 @@ impl FusionClientLocator { /// Get the fusion client for the given device. /// /// Provide the init function to create a new client if it isn't already initialized. - pub fn client( - &self, - device: &::Device, - ) -> C { + pub fn client(&self, device: &FusionDevice) -> Client { let device_id = device.id(); - let client_id = (core::any::TypeId::of::(), device_id); + let client_id = (core::any::TypeId::of::(), device_id); let mut clients = self.clients.lock(); if clients.is_none() { - let client = C::new(device.clone()); - Self::register_inner::(client_id, client, &mut clients); + let client = Client::::new(device.clone()); + Self::register_inner::(client_id, client, &mut clients); } match clients.deref_mut() { Some(clients) => match clients.get(&client_id) { Some(client) => { - let client: &C = client.downcast_ref().unwrap(); + let client: &Client = client.downcast_ref().unwrap(); client.clone() } None => { - let client = C::new(device.clone()); + let client = Client::::new(device.clone()); let any = Box::new(client.clone()); clients.insert(client_id, any); client @@ -56,9 +53,9 @@ impl FusionClientLocator { } } - fn register_inner( + fn register_inner( key: Key, - client: C, + client: Client, clients: &mut Option>>, ) { if clients.is_none() { diff --git a/crates/burn-fusion/src/ops/binary.rs b/crates/burn-fusion/src/ops/binary.rs index e7035f674e..a7ed09ab28 100644 --- a/crates/burn-fusion/src/ops/binary.rs +++ b/crates/burn-fusion/src/ops/binary.rs @@ -6,17 +6,18 @@ macro_rules! 
binary_float_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); - let rhs = handles.get_float_tensor(&self.desc.rhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); + let rhs = handles.get_float_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } }; @@ -30,17 +31,18 @@ macro_rules! binary_float_cmp_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); - let rhs = handles.get_float_tensor(&self.desc.rhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); + let rhs = handles.get_float_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } }; @@ -54,17 +56,18 @@ macro_rules! binary_int_cmp_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); - let rhs = handles.get_int_tensor(&self.desc.rhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); + let rhs = handles.get_int_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } }; @@ -88,17 +91,18 @@ macro_rules! 
binary_int_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); - let rhs = handles.get_int_tensor(&self.desc.rhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); + let rhs = handles.get_int_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index 4da93e0dfe..23b136b35d 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -1,3 +1,6 @@ +use burn_tensor::{DType, Element}; +use std::marker::PhantomData; + use crate::{ client::FusionClient, get_client, @@ -26,6 +29,7 @@ impl BoolTensorOps for Fusion { B::bool_tensor_handle(tensor), shape.dims.into(), StreamId::current(), + DType::Bool, ) } @@ -36,7 +40,7 @@ impl BoolTensorOps for Fusion { fn bool_into_data( tensor: BoolTensor, ) -> burn_tensor::Reader> { - tensor.bool_into_data() + tensor.bool_into_data::() } fn bool_from_data( @@ -51,6 +55,7 @@ impl BoolTensorOps for Fusion { B::bool_tensor_handle(tensor), shape.dims.into(), StreamId::current(), + DType::Bool, ) } @@ -58,20 +63,23 @@ impl BoolTensorOps for Fusion { tensor: BoolTensor, ) -> burn_tensor::ops::IntTensor { #[derive(new)] - struct IntoIntOps { + struct IntoIntOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for IntoIntOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + impl Operation for IntoIntOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_int(input); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -81,7 +89,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Bool(BoolOperationDescription::IntoInt(desc.clone())), - IntoIntOps::::new(desc), + IntoIntOps::::new(desc), ); out @@ -91,20 +99,23 @@ impl BoolTensorOps for Fusion { tensor: BoolTensor, ) -> burn_tensor::ops::FloatTensor { #[derive(new)] - struct IntoFloatOps { + struct IntoFloatOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for IntoFloatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + impl Operation for IntoFloatOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_float(input); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), 
B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -113,7 +124,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Bool(BoolOperationDescription::IntoFloat(desc.clone())), - IntoFloatOps::::new(desc), + IntoFloatOps::::new(desc), ); out @@ -138,7 +149,7 @@ impl BoolTensorOps for Fusion { let client_target = get_client::(&device_target); let client_original = tensor.client.clone(); - client_original.clone().change_client_bool::( + client_original.clone().change_client_bool::( tensor.into_description(), client_target, id, @@ -150,21 +161,24 @@ impl BoolTensorOps for Fusion { shape: Shape, ) -> BoolTensor { #[derive(new)] - struct ReshapeDimsOps { + struct ReshapeDimsOps { desc: ReshapeDescription, + _b: PhantomData, } - impl Operation for ReshapeDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + impl Operation + for ReshapeDimsOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_reshape::(input, Shape::from(&self.desc.out.shape)); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = ReshapeDescription { input: tensor.into_description(), @@ -173,7 +187,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::Reshape(desc.clone())), - ReshapeDimsOps::::new(desc), + ReshapeDimsOps::::new(desc), ); out @@ -184,18 +198,21 @@ impl BoolTensorOps for Fusion { ranges: [std::ops::Range; D2], ) -> BoolTensor { #[derive(new)] - struct SliceOps { + struct SliceOps { desc: SliceOperationDescription, + _b: PhantomData, } - impl Operation for SliceOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_bool_tensor::(&self.desc.tensor); + impl Operation + for SliceOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_bool_tensor::(&self.desc.tensor); let output = B::bool_slice::(tensor, self.desc.ranges.clone().try_into().unwrap()); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -206,7 +223,7 @@ impl BoolTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = SliceOperationDescription { tensor: tensor.into_description(), @@ -216,7 +233,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::Slice(desc.clone())), - SliceOps::::new(desc), + SliceOps::::new(desc), ); out @@ -228,14 +245,17 @@ impl BoolTensorOps for Fusion { value: BoolTensor, ) -> BoolTensor { #[derive(new)] - struct SliceAssignOps { + struct SliceAssignOps { desc: SliceAssignOperationDescription, + _b: PhantomData, } - impl Operation for SliceAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_bool_tensor::(&self.desc.tensor); - let value = handles.get_bool_tensor::(&self.desc.value); + impl Operation + for SliceAssignOps + { + fn execute(self: Box, handles: 
&mut HandleContainer) { + let tensor = handles.get_bool_tensor::(&self.desc.tensor); + let value = handles.get_bool_tensor::(&self.desc.value); let output = B::bool_slice_assign::( tensor, @@ -243,14 +263,14 @@ impl BoolTensorOps for Fusion { value, ); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } let shape: Vec = tensor.shape.clone(); let stream_1 = tensor.stream; let stream_2 = value.stream; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = SliceAssignOperationDescription { tensor: tensor.into_description(), @@ -262,7 +282,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseBool(BaseOperationDescription::SliceAssign(desc.clone())), - SliceAssignOps::::new(desc), + SliceAssignOps::::new(desc), ); out @@ -273,22 +293,23 @@ impl BoolTensorOps for Fusion { dim: usize, ) -> BoolTensor { #[derive(new)] - struct CatOps { + struct CatOps { desc: CatOperationDescription, + _b: PhantomData, } - impl Operation for CatOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for CatOps { + fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc .tensors .iter() - .map(|tensor| handles.get_bool_tensor(tensor)) + .map(|tensor| handles.get_bool_tensor::(tensor)) .collect(); let output = B::bool_cat::(tensors, self.desc.dim); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -304,7 +325,7 @@ impl BoolTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, DType::Bool); let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(), @@ -314,7 +335,7 @@ impl BoolTensorOps for Fusion { client.register( streams, OperationDescription::BaseBool(BaseOperationDescription::Cat(desc.clone())), - CatOps::::new(desc), + CatOps::::new(desc), ); out @@ -325,16 +346,17 @@ impl BoolTensorOps for Fusion { rhs: BoolTensor, ) -> BoolTensor { #[derive(new)] - struct EqualOps { + struct EqualOps { desc: BinaryOperationDescription, + _b: PhantomData, } - impl Operation for EqualOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_bool_tensor::(&self.desc.lhs); - let rhs = handles.get_bool_tensor(&self.desc.rhs); + impl Operation for EqualOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_bool_tensor::(&self.desc.lhs); + let rhs = handles.get_bool_tensor::(&self.desc.rhs); let output = B::bool_equal(lhs, rhs); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -342,7 +364,7 @@ impl BoolTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -352,7 +374,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseBool(BaseOperationDescription::Equal(desc.clone())), - EqualOps::::new(desc), + EqualOps::::new(desc), ); out @@ -360,20 +382,23 @@ impl BoolTensorOps for Fusion { fn bool_not(tensor: BoolTensor) -> BoolTensor { 
#[derive(new)] - struct NotOps { + struct NotOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for NotOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + impl Operation for NotOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_not(input); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), DType::Bool); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -383,7 +408,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Bool(BoolOperationDescription::Not(desc.clone())), - NotOps::::new(desc), + NotOps::::new(desc), ); out @@ -395,15 +420,16 @@ impl BoolTensorOps for Fusion { dim2: usize, ) -> BoolTensor { #[derive(new)] - struct SwapDimsOps { + struct SwapDimsOps { desc: SwapDimsDescription, + _b: PhantomData, } - impl Operation for SwapDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + impl Operation for SwapDimsOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -412,7 +438,7 @@ impl BoolTensorOps for Fusion { shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = SwapDimsDescription { input: tensor.into_description(), @@ -423,7 +449,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::SwapDims(desc.clone())), - SwapDimsOps::::new(desc), + SwapDimsOps::::new(desc), ); out @@ -434,16 +460,17 @@ impl BoolTensorOps for Fusion { axes: [usize; D], ) -> BoolTensor { #[derive(new)] - struct PermuteDimsOps { + struct PermuteDimsOps { desc: PermuteOperationDescription, + _b: PhantomData, } - impl Operation for PermuteDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + impl Operation for PermuteDimsOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); let output = B::bool_permute(input, axes); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } @@ -452,7 +479,7 @@ impl BoolTensorOps for Fusion { // Change the shape of the tensor to match the new axes let shape = axes.into_iter().map(|x| tensor.shape[x]).collect(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = PermuteOperationDescription { input: tensor.into_description(), @@ -463,7 +490,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Permute(desc.clone())), - 
PermuteDimsOps::::new(desc), + PermuteDimsOps::::new(desc), ); out @@ -474,23 +501,28 @@ impl BoolTensorOps for Fusion { shape: Shape, ) -> BoolTensor { #[derive(new)] - struct ExpandOps { + struct ExpandOps { desc: ExpandOperationDescription, + _b: PhantomData, } - impl Operation for ExpandOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + impl Operation + for ExpandOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); let output = B::bool_expand(input, shape.into()); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(shape.dims.into()); + let out = tensor + .client + .tensor_uninitialized(shape.dims.into(), DType::Bool); let desc = ExpandOperationDescription { input: tensor.into_description(), @@ -501,7 +533,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::Expand(desc.clone())), - ExpandOps::::new(desc), + ExpandOps::::new(desc), ); out @@ -512,20 +544,23 @@ impl BoolTensorOps for Fusion { axes: &[usize], ) -> BoolTensor { #[derive(new)] - struct FlipOps { + struct FlipOps { desc: FlipOperationDescription, + _b: PhantomData, } - impl Operation for FlipOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + impl Operation for FlipOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_flip(input, self.desc.axes.as_slice()); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), DType::Bool); let desc = FlipOperationDescription { input: tensor.into_description(), @@ -536,7 +571,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseBool(BaseOperationDescription::Flip(desc.clone())), - FlipOps::::new(desc), + FlipOps::::new(desc), ); out @@ -548,24 +583,25 @@ impl BoolTensorOps for Fusion { times: usize, ) -> BoolTensor { #[derive(new)] - struct RepeatOps { + struct RepeatOps { desc: RepeatOperationDescription, + _b: PhantomData, } - impl Operation for RepeatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_bool_tensor::(&self.desc.tensor); + impl Operation for RepeatOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_bool_tensor::(&self.desc.tensor); let output = B::bool_repeat::(tensor, self.desc.dim, self.desc.times); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] *= times; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let desc = RepeatOperationDescription { tensor: tensor.into_description(), @@ -576,7 +612,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], 
OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())), - RepeatOps::::new(desc), + RepeatOps::::new(desc), ); out diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index a24f7045b5..e018ae632c 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -7,23 +7,12 @@ use crate::{ stream::{execution::Operation, StreamId}, unary_float_ops, Fusion, FusionBackend, }; - use burn_tensor::{ ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, - repr::{ - BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, - ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - FloatOperationDescription, GatherOperationDescription, HandleContainer, - MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, - OperationDescription, PermuteOperationDescription, RandomOperationDescription, - ReduceDimWithIndicesDescription, RepeatOperationDescription, ReshapeDescription, - ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, - SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, - SwapDimsDescription, TensorDescription, UnaryOperationDescription, - }, - Data, Device, Distribution, ElementConversion, Reader, Shape, + repr::*, + DType, Data, Device, Distribution, Element, ElementConversion, Reader, Shape, }; -use std::ops::Range; +use std::{marker::PhantomData, ops::Range}; impl FloatTensorOps for Fusion { fn float_from_data( @@ -38,6 +27,7 @@ impl FloatTensorOps for Fusion { B::float_tensor_handle(tensor), shape.dims.into(), StreamId::current(), + B::FloatElem::dtype(), ) } @@ -47,23 +37,24 @@ impl FloatTensorOps for Fusion { device: &Device, ) -> FloatTensor { #[derive(new)] - struct RandomOps { + struct RandomOps { desc: RandomOperationDescription, + device: Device, } - impl Operation for RandomOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for RandomOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.out.shape.clone()); let output: B::FloatTensorPrimitive = - B::float_random(shape, self.desc.distribution, &handles.device); - handles.register_float_tensor(&self.desc.out.id, output); + B::float_random(shape, self.desc.distribution, &self.device); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = RandomOperationDescription { out: out.to_description_out(), @@ -72,7 +63,7 @@ impl FloatTensorOps for Fusion { client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Random(desc.clone())), - RandomOps::::new(desc), + RandomOps::::new(desc, device.clone()), ); out @@ -80,28 +71,29 @@ impl FloatTensorOps for Fusion { fn float_zeros(shape: Shape, device: &Device) -> FloatTensor { #[derive(new)] - struct ZerosOps { + struct ZerosOps { out: TensorDescription, + device: Device, } - impl Operation for ZerosOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for ZerosOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); - let output = B::float_zeros::(shape, &handles.device); - 
handles.register_float_tensor(&self.out.id, output); + let output = B::float_zeros::(shape, &self.device); + handles.register_float_tensor::(&self.out.id, output); } } let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = out.to_description_out(); client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Zeros(desc.clone())), - ZerosOps::::new(desc), + ZerosOps::::new(desc, device.clone()), ); out @@ -109,28 +101,29 @@ impl FloatTensorOps for Fusion { fn float_ones(shape: Shape, device: &Device) -> FloatTensor { #[derive(new)] - struct OnesOps { + struct OnesOps { out: TensorDescription, + device: Device, } - impl Operation for OnesOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for OnesOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); - let output = B::float_ones::(shape, &handles.device); - handles.register_float_tensor(&self.out.id, output); + let output = B::float_ones::(shape, &self.device); + handles.register_float_tensor::(&self.out.id, output); } } let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = out.to_description_out(); client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Ones(desc.clone())), - OnesOps::::new(desc), + OnesOps::::new(desc, device.clone()), ); out @@ -142,30 +135,31 @@ impl FloatTensorOps for Fusion { device: &Device, ) -> FloatTensor { #[derive(new)] - struct FullOps { + struct FullOps { out: TensorDescription, elem: f32, + device: Device, } - impl Operation for FullOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for FullOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); let output: B::FloatTensorPrimitive = - B::float_full(shape, self.elem.elem(), &handles.device); - handles.register_float_tensor(&self.out.id, output); + B::float_full(shape, self.elem.elem(), &self.device); + handles.register_float_tensor::(&self.out.id, output); } } let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = (out.to_description_out(), fill_value.elem::()); client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Full(desc.clone())), - FullOps::::new(desc.0, desc.1), + FullOps::::new(desc.0, desc.1, device.clone()), ); out @@ -178,7 +172,7 @@ impl FloatTensorOps for Fusion { fn float_into_data( tensor: FloatTensor, ) -> Reader, D>> { - tensor.into_data() + tensor.into_data::() } fn float_device(tensor: &FloatTensor) -> Device { @@ -200,7 +194,7 @@ impl FloatTensorOps for Fusion { let client_target = get_client::(&device_target); let client_original = tensor.client.clone(); - client_original.clone().change_client_float::( + client_original.clone().change_client_float::( tensor.into_description(), client_target, id, @@ -209,21 +203,24 @@ impl FloatTensorOps for Fusion { fn float_into_int(tensor: FloatTensor) -> IntTensor { 
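+        // Like every other op here, the cast is recorded lazily: `IntoIntOps` below
+        // only runs against real tensors once the stream is drained.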
#[derive(new)] - struct IntoIntOps { + struct IntoIntOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for IntoIntOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + impl Operation for IntoIntOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_into_int(input); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -232,7 +229,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::IntoInt(desc.clone())), - IntoIntOps::::new(desc), + IntoIntOps::::new(desc), ); out @@ -243,7 +240,12 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let tensor = B::float_empty(shape.clone(), device); - client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into(), stream) + client.register_tensor( + B::float_tensor_handle(tensor), + shape.dims.into(), + stream, + B::FloatElem::dtype(), + ) } fn float_add( @@ -254,9 +256,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -267,7 +270,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Add(desc.clone())), - AddOps::::new(desc), + AddOps::::new(desc), ); out @@ -280,7 +283,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(AddOps, B::float_add_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -292,7 +297,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::AddScalar( desc.clone(), )), - AddOps::::new(desc), + AddOps::::new(desc), ); out @@ -304,21 +309,24 @@ impl FloatTensorOps for Fusion { max: FloatElem, ) -> FloatTensor { #[derive(new)] - struct ClampOps { + struct ClampOps { desc: ClampOperationDescription, + _b: PhantomData, } - impl Operation for ClampOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.tensor); + impl Operation for ClampOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_clamp(input, self.desc.min.elem(), self.desc.max.elem()); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = ClampOperationDescription { 
tensor: tensor.into_description(), @@ -329,7 +337,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Clamp(desc.clone())), - ClampOps::::new(desc), + ClampOps::::new(desc), ); out @@ -343,9 +351,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -355,7 +364,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Sub(desc.clone())), - SubOps::::new(desc), + SubOps::::new(desc), ); out @@ -368,7 +377,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(SubOps, B::float_sub_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), rhs: rhs.elem(), @@ -380,7 +391,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::SubScalar( desc.clone(), )), - SubOps::::new(desc), + SubOps::::new(desc), ); out @@ -394,9 +405,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -406,7 +418,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Mul(desc.clone())), - MulOps::::new(desc), + MulOps::::new(desc), ); out @@ -419,7 +431,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(MulOps, B::float_mul_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -431,7 +445,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::MulScalar( desc.clone(), )), - MulOps::::new(desc), + MulOps::::new(desc), ); out @@ -445,9 +459,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -457,7 +472,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Div(desc.clone())), - DivOps::::new(desc), + DivOps::::new(desc), ); out @@ -470,7 +485,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(DivOps, B::float_div_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc 
= ScalarOperationDescription { lhs: lhs.into_description(), @@ -482,7 +499,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::DivScalar( desc.clone(), )), - DivOps::::new(desc), + DivOps::::new(desc), ); out @@ -495,7 +512,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(ModOps, B::float_remainder_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -507,7 +526,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::RemScalar( desc.clone(), )), - ModOps::::new(desc), + ModOps::::new(desc), ); out @@ -526,7 +545,9 @@ impl FloatTensorOps for Fusion { shape[D - 2] = lhs.shape[D - 2]; shape[D - 1] = rhs.shape[D - 1]; - let out = lhs.client.tensor_uninitialized(shape); + let out = lhs + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = BinaryOperationDescription { lhs: lhs.into_description(), rhs: rhs.into_description(), @@ -536,7 +557,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::Float(FloatOperationDescription::Matmul(desc.clone())), - MatmulOps::::new(desc), + MatmulOps::::new(desc), ); out @@ -548,15 +569,16 @@ impl FloatTensorOps for Fusion { dim2: usize, ) -> FloatTensor { #[derive(new)] - struct SwapDimsOps { + struct SwapDimsOps { desc: SwapDimsDescription, + _b: PhantomData, } - impl Operation for SwapDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + impl Operation for SwapDimsOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -565,7 +587,9 @@ impl FloatTensorOps for Fusion { shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let mut out = tensor.client.tensor_uninitialized(shape); + let mut out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SwapDimsDescription { input: tensor.into_description(), @@ -576,7 +600,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseFloat(BaseOperationDescription::SwapDims(desc.clone())), - SwapDimsOps::::new(desc), + SwapDimsOps::::new(desc), ); out.stream = stream; @@ -588,21 +612,26 @@ impl FloatTensorOps for Fusion { shape: Shape, ) -> FloatTensor { #[derive(new)] - struct ReshapeDimsOps { + struct ReshapeDimsOps { desc: ReshapeDescription, + _b: PhantomData, } - impl Operation for ReshapeDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + impl Operation + for ReshapeDimsOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_reshape::(input, Shape::from(&self.desc.out.shape)); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + 
.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ReshapeDescription { input: tensor.into_description(), @@ -611,7 +640,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseFloat(BaseOperationDescription::Reshape(desc.clone())), - ReshapeDimsOps::::new(desc), + ReshapeDimsOps::::new(desc), ); out @@ -623,24 +652,27 @@ impl FloatTensorOps for Fusion { indices: IntTensor, ) -> FloatTensor { #[derive(new)] - struct GatherOps { + struct GatherOps { desc: GatherOperationDescription, + _b: PhantomData, } - impl Operation for GatherOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); + impl Operation for GatherOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::float_gather(self.desc.dim, tensor, indices); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream_1 = tensor.stream; let stream_2 = indices.stream; let shape: Vec = indices.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = GatherOperationDescription { tensor: tensor.into_description(), @@ -651,7 +683,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Gather(desc.clone())), - GatherOps::::new(desc), + GatherOps::::new(desc), ); out @@ -664,19 +696,20 @@ impl FloatTensorOps for Fusion { value: FloatTensor, ) -> FloatTensor { #[derive(new)] - struct ScatterOps { + struct ScatterOps { desc: ScatterOperationDescription, + _b: PhantomData, } - impl Operation for ScatterOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); - let value = handles.get_float_tensor(&self.desc.value); + impl Operation for ScatterOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + let value = handles.get_float_tensor::(&self.desc.value); let output = B::float_scatter(self.desc.dim, tensor, indices, value); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -684,7 +717,9 @@ impl FloatTensorOps for Fusion { let stream_2 = indices.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScatterOperationDescription { tensor: tensor.into_description(), @@ -697,7 +732,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2, stream_3], OperationDescription::NumericFloat(NumericOperationDescription::Scatter(desc.clone())), - ScatterOps::::new(desc), + ScatterOps::::new(desc), ); out @@ -709,18 +744,19 @@ impl FloatTensorOps for Fusion { indices: IntTensor, ) -> FloatTensor { #[derive(new)] - struct SelectOps { + struct SelectOps { desc: SelectOperationDescription, + _b: PhantomData, } - impl Operation for SelectOps { 
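`SwapDimsOps`, `ReshapeDimsOps`, `GatherOps` and the rest now carry the backend as a type parameter, and since no backend value is actually stored, a `PhantomData<B>` field keeps the parameter alive; `execute` then resolves the backend statically. A reduced sketch of that pattern (the toy `Operation` and `HandleContainer` below stand in for the burn-fusion ones, which are generic over a fusion runtime):

    use std::marker::PhantomData;

    struct HandleContainer; // stands in for the real handle store

    trait Backend: 'static {
        fn name() -> &'static str;
    }

    trait Operation<B: Backend> {
        fn execute(self: Box<Self>, handles: &mut HandleContainer);
    }

    // The op owns only its description; `_b` pins `B` at the type level.
    struct GatherOps<B: Backend> {
        dim: usize,
        _b: PhantomData<B>,
    }

    impl<B: Backend> Operation<B> for GatherOps<B> {
        fn execute(self: Box<Self>, _handles: &mut HandleContainer) {
            println!("gather along dim {} on {}", self.dim, B::name());
        }
    }

    struct Cpu;
    impl Backend for Cpu {
        fn name() -> &'static str { "cpu" }
    }

    fn main() {
        let op = Box::new(GatherOps::<Cpu> { dim: 1, _b: PhantomData });
        op.execute(&mut HandleContainer);
    }
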
- fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); + impl Operation for SelectOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::float_select(tensor, self.desc.dim, indices); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -728,7 +764,9 @@ impl FloatTensorOps for Fusion { let stream_2 = indices.stream; let mut shape: Vec = tensor.shape.clone(); shape[dim] = indices.shape[0]; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SelectOperationDescription { tensor: tensor.into_description(), dim, @@ -738,7 +776,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Select(desc.clone())), - SelectOps::::new(desc), + SelectOps::::new(desc), ); out @@ -751,19 +789,20 @@ impl FloatTensorOps for Fusion { value: FloatTensor, ) -> FloatTensor { #[derive(new)] - struct SelectAssignOps { + struct SelectAssignOps { desc: SelectAssignOperationDescription, + _b: PhantomData, } - impl Operation for SelectAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); - let value = handles.get_float_tensor(&self.desc.value); + impl Operation for SelectAssignOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + let value = handles.get_float_tensor::(&self.desc.value); let output = B::float_select_assign(tensor, self.desc.dim, indices, value); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -771,7 +810,9 @@ impl FloatTensorOps for Fusion { let stream_2 = indices.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SelectAssignOperationDescription { tensor: tensor.into_description(), @@ -785,7 +826,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::SelectAssign( desc.clone(), )), - SelectAssignOps::::new(desc), + SelectAssignOps::::new(desc), ); out @@ -796,18 +837,21 @@ impl FloatTensorOps for Fusion { ranges: [Range; D2], ) -> FloatTensor { #[derive(new)] - struct SliceOps { + struct SliceOps { desc: SliceOperationDescription, + _b: PhantomData, } - impl Operation for SliceOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); + impl Operation + for SliceOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_slice::(tensor, self.desc.ranges.clone().try_into().unwrap()); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; @@ -817,7 +861,9 @@ impl 
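`float_select` shows how output shapes are derived before anything executes: the selected dimension takes the number of indices while every other dimension passes through unchanged. That rule as a standalone helper (the helper name is ours, not a burn API):

    /// Output shape of `select(tensor, dim, indices)`:
    /// dimension `dim` takes the index count, the rest pass through.
    fn select_output_shape(input: &[usize], dim: usize, num_indices: usize) -> Vec<usize> {
        let mut shape = input.to_vec();
        shape[dim] = num_indices;
        shape
    }

    fn main() {
        // Selecting 5 rows out of an [8, 4] tensor along dim 0 yields [5, 4].
        assert_eq!(select_output_shape(&[8, 4], 0, 5), vec![5, 4]);
    }
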
FloatTensorOps for Fusion { shape.push(tensor.shape[i]); } - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SliceOperationDescription { tensor: tensor.into_description(), @@ -827,7 +873,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseFloat(BaseOperationDescription::Slice(desc.clone())), - SliceOps::::new(desc), + SliceOps::::new(desc), ); out @@ -839,14 +885,17 @@ impl FloatTensorOps for Fusion { value: FloatTensor, ) -> FloatTensor { #[derive(new)] - struct SliceAssignOps { + struct SliceAssignOps { desc: SliceAssignOperationDescription, + _b: PhantomData, } - impl Operation for SliceAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let value = handles.get_float_tensor::(&self.desc.value); + impl Operation + for SliceAssignOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let value = handles.get_float_tensor::(&self.desc.value); let output = B::float_slice_assign::( tensor, @@ -854,14 +903,16 @@ impl FloatTensorOps for Fusion { value, ); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream_1 = tensor.stream; let stream_2 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = SliceAssignOperationDescription { tensor: tensor.into_description(), @@ -872,7 +923,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseFloat(BaseOperationDescription::SliceAssign(desc.clone())), - SliceAssignOps::::new(desc), + SliceAssignOps::::new(desc), ); out @@ -884,19 +935,20 @@ impl FloatTensorOps for Fusion { value: FloatTensor, ) -> FloatTensor { #[derive(new)] - struct MaskWhereOps { + struct MaskWhereOps { desc: MaskWhereOperationDescription, + _b: PhantomData, } - impl Operation for MaskWhereOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let value = handles.get_float_tensor(&self.desc.value); - let mask = handles.get_bool_tensor(&self.desc.mask); + impl Operation for MaskWhereOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let value = handles.get_float_tensor::(&self.desc.value); + let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::float_mask_where(tensor, mask, value); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -904,7 +956,9 @@ impl FloatTensorOps for Fusion { let stream_2 = mask.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = MaskWhereOperationDescription { tensor: tensor.into_description(), @@ -917,7 +971,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::MaskWhere( desc.clone(), )), - MaskWhereOps::::new(desc), + MaskWhereOps::::new(desc), ); out @@ -929,25 +983,28 @@ impl FloatTensorOps for Fusion { value: FloatElem, ) 
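Similarly for `float_slice`: the leading output dimensions come from the range lengths, and trailing dimensions are copied from the input (the `shape.push(tensor.shape[i])` loop above). A sketch of that shape computation, again with a name of our own:

    use std::ops::Range;

    /// Output shape when slicing the leading dims by `ranges`,
    /// keeping the trailing dims of `input` as-is.
    fn slice_output_shape(input: &[usize], ranges: &[Range<usize>]) -> Vec<usize> {
        let mut shape: Vec<usize> = ranges.iter().map(|r| r.end - r.start).collect();
        shape.extend_from_slice(&input[ranges.len()..]);
        shape
    }

    fn main() {
        // Slicing a [4, 6, 2] tensor with [1..3, 0..6] yields [2, 6, 2].
        assert_eq!(slice_output_shape(&[4, 6, 2], &[1..3, 0..6]), vec![2, 6, 2]);
    }
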
-> FloatTensor { #[derive(new)] - struct MaskFillOps { + struct MaskFillOps { desc: MaskFillOperationDescription, + _b: PhantomData, } - impl Operation for MaskFillOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); - let mask = handles.get_bool_tensor(&self.desc.mask); + impl Operation for MaskFillOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::float_mask_fill(tensor, mask, self.desc.value.elem()); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream_1 = tensor.stream; let stream_2 = mask.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = MaskFillOperationDescription { tensor: tensor.into_description(), value: value.elem(), @@ -957,7 +1014,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::MaskFill(desc.clone())), - MaskFillOps::::new(desc), + MaskFillOps::::new(desc), ); out @@ -973,7 +1030,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -983,7 +1040,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseFloat(BaseOperationDescription::Equal(desc.clone())), - EqualOps::::new(desc), + EqualOps::::new(desc), ); out @@ -996,7 +1053,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1008,7 +1067,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::EqualElem( desc.clone(), )), - EqualElemOps::::new(desc), + EqualElemOps::::new(desc), ); out @@ -1024,7 +1083,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1034,7 +1093,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Greater(desc.clone())), - GreaterOps::::new(desc), + GreaterOps::::new(desc), ); out @@ -1047,7 +1106,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1059,7 +1120,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::GreaterElem( desc.clone(), )), - 
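`mask_where` and `mask_fill` differ only in where the replacement comes from: a second tensor versus a single scalar. Reference (unfused) semantics, for orientation only:

    /// out[i] = if mask[i] { value[i] } else { tensor[i] }
    fn mask_where(tensor: &[f32], mask: &[bool], value: &[f32]) -> Vec<f32> {
        tensor
            .iter()
            .zip(mask)
            .zip(value)
            .map(|((&t, &m), &v)| if m { v } else { t })
            .collect()
    }

    /// out[i] = if mask[i] { fill } else { tensor[i] }
    fn mask_fill(tensor: &[f32], mask: &[bool], fill: f32) -> Vec<f32> {
        tensor
            .iter()
            .zip(mask)
            .map(|(&t, &m)| if m { fill } else { t })
            .collect()
    }

    fn main() {
        let t = [1.0, 2.0, 3.0];
        let m = [true, false, true];
        assert_eq!(mask_where(&t, &m, &[9.0, 9.0, 9.0]), vec![9.0, 2.0, 9.0]);
        assert_eq!(mask_fill(&t, &m, 0.0), vec![0.0, 2.0, 0.0]);
    }
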
GreaterElemOps::::new(desc), + GreaterElemOps::::new(desc), ); out @@ -1075,7 +1136,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1087,7 +1148,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqual( desc.clone(), )), - GreaterEqualOps::::new(desc), + GreaterEqualOps::::new(desc), ); out @@ -1100,7 +1161,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1112,7 +1175,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqualElem( desc.clone(), )), - GreaterEqualElemOps::::new(desc), + GreaterEqualElemOps::::new(desc), ); out @@ -1128,7 +1191,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1138,7 +1201,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Lower(desc.clone())), - LowerOps::::new(desc), + LowerOps::::new(desc), ); out @@ -1151,7 +1214,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1163,7 +1228,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::LowerElem( desc.clone(), )), - LowerElemOps::::new(desc), + LowerElemOps::::new(desc), ); out @@ -1179,7 +1244,7 @@ impl FloatTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1191,7 +1256,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::LowerEqual( desc.clone(), )), - LowerEqualOps::::new(desc), + LowerEqualOps::::new(desc), ); out @@ -1204,7 +1269,9 @@ impl FloatTensorOps for Fusion { scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1216,17 +1283,19 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::LowerEqualElem( desc.clone(), )), - LowerEqualElemOps::::new(desc), + LowerEqualElemOps::::new(desc), ); out } fn float_sum(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(SumOps, 
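Every comparison (`equal`, `greater`, `lower`, and their `_elem` variants) now allocates its output as `DType::Bool` over the broadcast shape of the operands. A sketch of the usual same-rank broadcast rule, under the assumption that mismatched dimensions only occur when one of them is 1 (the real `binary_ops_shape` lives in burn-fusion and may differ in detail):

    /// Same-rank broadcast: a dim of 1 stretches to match the other side.
    fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
        lhs.iter()
            .zip(rhs)
            .map(|(&l, &r)| {
                debug_assert!(l == r || l == 1 || r == 1, "incompatible dims");
                usize::max(l, r)
            })
            .collect()
    }

    fn main() {
        // [2, 1, 4] vs [2, 3, 1] broadcasts to [2, 3, 4]; the comparison
        // output would be allocated with this shape and DType::Bool.
        assert_eq!(binary_ops_shape(&[2, 1, 4], &[2, 3, 1]), vec![2, 3, 4]);
    }
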
B::float_sum); + unary_float_ops!(SumOps, B::float_sum, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1235,7 +1304,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Sum(desc.clone())), - SumOps::::new(desc), + SumOps::::new(desc), ); out @@ -1250,7 +1319,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1260,17 +1331,19 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::SumDim(desc.clone())), - SumDimOps::::new(desc), + SumDimOps::::new(desc), ); out } fn float_mean(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MeanOps, B::float_mean); + unary_float_ops!(MeanOps, B::float_mean, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1279,7 +1352,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Mean(desc.clone())), - MeanOps::::new(desc), + MeanOps::::new(desc), ); out @@ -1294,7 +1367,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1304,7 +1379,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::MeanDim(desc.clone())), - MeanDimOps::::new(desc), + MeanDimOps::::new(desc), ); out @@ -1314,7 +1389,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(ExpOps, B::float_exp); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: lhs.into_description(), @@ -1323,7 +1400,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Exp(desc.clone())), - ExpOps::::new(desc), + ExpOps::::new(desc), ); out @@ -1333,7 +1410,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(LogOps, B::float_log); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1342,7 +1421,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Log(desc.clone())), - LogOps::::new(desc), + LogOps::::new(desc), ); out @@ -1352,7 +1431,9 @@ impl FloatTensorOps for Fusion { 
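Full reductions (`sum`, `mean`, and later `max`/`min`, now tagged `reduce` in the macro) collapse to a single element, hence the `vec![1]` output shape, while the `_dim` variants keep the rank and set the reduced dimension to 1. Reference semantics for the two flavors in plain Rust, not the fused path:

    /// Full reduction: one element, shape [1].
    fn sum_all(data: &[f32]) -> (Vec<f32>, Vec<usize>) {
        (vec![data.iter().sum()], vec![1])
    }

    /// A dim reduction keeps the rank: shape[dim] becomes 1.
    fn reduce_dim_shape(mut shape: Vec<usize>, dim: usize) -> Vec<usize> {
        shape[dim] = 1;
        shape
    }

    fn main() {
        let (out, shape) = sum_all(&[1.0, 2.0, 3.0]);
        assert_eq!((out, shape), (vec![6.0], vec![1]));
        assert_eq!(reduce_dim_shape(vec![2, 3, 4], 1), vec![2, 1, 4]);
    }
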
unary_float_ops!(Log1pOps, B::float_log1p); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1361,7 +1442,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Log1p(desc.clone())), - Log1pOps::::new(desc), + Log1pOps::::new(desc), ); out @@ -1374,7 +1455,9 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(PowfOps, B::float_powf_scalar, f32); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -1384,7 +1467,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::PowfScalar(desc.clone())), - PowfOps::::new(desc), + PowfOps::::new(desc), ); out @@ -1394,7 +1477,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(SqrtOps, B::float_sqrt); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1403,7 +1488,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Sqrt(desc.clone())), - SqrtOps::::new(desc), + SqrtOps::::new(desc), ); out @@ -1413,7 +1498,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(AbsOps, B::float_abs); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1422,7 +1509,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Abs(desc.clone())), - AbsOps::::new(desc), + AbsOps::::new(desc), ); out @@ -1432,7 +1519,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(CosOps, B::float_cos); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1441,7 +1530,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Cos(desc.clone())), - CosOps::::new(desc), + CosOps::::new(desc), ); out @@ -1451,7 +1540,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(SinOps, B::float_sin); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1460,7 +1551,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Sin(desc.clone())), - SinOps::::new(desc), + SinOps::::new(desc), ); out @@ -1470,7 +1561,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(TanhOps, 
B::float_tanh); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1479,7 +1572,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Tanh(desc.clone())), - TanhOps::::new(desc), + TanhOps::::new(desc), ); out @@ -1489,7 +1582,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(Recip, B::float_recip); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), @@ -1497,7 +1592,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Recip(desc.clone())), - Recip::::new(desc), + Recip::::new(desc), ); out @@ -1507,7 +1602,9 @@ impl FloatTensorOps for Fusion { unary_float_ops!(TanhOps, B::float_erf); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1516,7 +1613,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Float(FloatOperationDescription::Erf(desc.clone())), - TanhOps::::new(desc), + TanhOps::::new(desc), ); out @@ -1527,22 +1624,23 @@ impl FloatTensorOps for Fusion { dim: usize, ) -> FloatTensor { #[derive(new)] - struct CatOps { + struct CatOps { desc: CatOperationDescription, + _b: PhantomData, } - impl Operation for CatOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for CatOps { + fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc .tensors .iter() - .map(|tensor| handles.get_float_tensor(tensor)) + .map(|tensor| handles.get_float_tensor::(tensor)) .collect(); let output = B::float_cat::(tensors, self.desc.dim); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -1557,7 +1655,7 @@ impl FloatTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(), @@ -1567,7 +1665,7 @@ impl FloatTensorOps for Fusion { client.register( streams, OperationDescription::BaseFloat(BaseOperationDescription::Cat(desc.clone())), - CatOps::::new(desc), + CatOps::::new(desc), ); out @@ -1582,7 +1680,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1592,7 +1692,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::ArgMax(desc.clone())), - ArgMaxOps::::new(desc), + ArgMaxOps::::new(desc), ); out @@ -1604,24 +1704,27 @@ impl 
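`float_cat` computes its output shape by summing the concatenation dimension across inputs, all other dims staying equal; right after it, `float_argmax` is the first op here whose output dtype (`B::IntElem::dtype()`) differs from its input's. The cat shape rule as a standalone helper (our name, assuming at least one input tensor):

    /// Concatenation output shape: dims match except `dim`, which sums.
    fn cat_shape(shapes: &[Vec<usize>], dim: usize) -> Vec<usize> {
        let mut out = shapes[0].clone();
        for s in &shapes[1..] {
            out[dim] += s[dim];
        }
        out
    }

    fn main() {
        let shapes = vec![vec![2, 3], vec![2, 5]];
        assert_eq!(cat_shape(&shapes, 1), vec![2, 8]); // [2,3] ++ [2,5] along dim 1
    }
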
FloatTensorOps for Fusion { times: usize, ) -> FloatTensor { #[derive(new)] - struct RepeatOps { + struct RepeatOps { desc: RepeatOperationDescription, + _b: PhantomData, } - impl Operation for RepeatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); + impl Operation for RepeatOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_repeat::(tensor, self.desc.dim, self.desc.times); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] *= times; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = RepeatOperationDescription { tensor: tensor.into_description(), @@ -1632,7 +1735,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())), - RepeatOps::::new(desc), + RepeatOps::::new(desc), ); out @@ -1647,7 +1750,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1657,17 +1762,19 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::ArgMin(desc.clone())), - ArgMinOps::::new(desc), + ArgMinOps::::new(desc), ); out } fn float_max(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MaxOps, B::float_max); + unary_float_ops!(MaxOps, B::float_max, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1676,7 +1783,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Max(desc.clone())), - MaxOps::::new(desc), + MaxOps::::new(desc), ); out @@ -1691,7 +1798,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1701,7 +1810,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::MaxDim(desc.clone())), - MaxDimOps::::new(desc), + MaxDimOps::::new(desc), ); out @@ -1712,17 +1821,18 @@ impl FloatTensorOps for Fusion { dim: usize, ) -> (FloatTensor, IntTensor) { #[derive(new)] - struct MaxDimWithIndicesOps { + struct MaxDimWithIndicesOps { desc: ReduceDimWithIndicesDescription, + _b: PhantomData, } - impl Operation for MaxDimWithIndicesOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); + impl Operation for MaxDimWithIndicesOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = 
handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim); - handles.register_float_tensor(&self.desc.out.id, output); - handles.register_int_tensor(&self.desc.out_indices.id, indices); + handles.register_float_tensor::(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } @@ -1730,8 +1840,8 @@ impl FloatTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape.clone(), B::FloatElem::dtype()); + let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReduceDimWithIndicesDescription { tensor: tensor.into_description(), @@ -1744,17 +1854,19 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::MaxDimWithIndices( desc.clone(), )), - MaxDimWithIndicesOps::::new(desc), + MaxDimWithIndicesOps::::new(desc), ); (out, out_indices) } fn float_min(tensor: FloatTensor) -> FloatTensor { - unary_float_ops!(MinOps, B::float_min); + unary_float_ops!(MinOps, B::float_min, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1763,7 +1875,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::Min(desc.clone())), - MinOps::::new(desc), + MinOps::::new(desc), ); out @@ -1778,7 +1890,9 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1788,7 +1902,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat(NumericOperationDescription::MinDim(desc.clone())), - MinDimOps::::new(desc), + MinDimOps::::new(desc), ); out @@ -1799,17 +1913,18 @@ impl FloatTensorOps for Fusion { dim: usize, ) -> (FloatTensor, IntTensor) { #[derive(new)] - struct MinDimWithIndicesOps { + struct MinDimWithIndicesOps { desc: ReduceDimWithIndicesDescription, + _b: PhantomData, } - impl Operation for MinDimWithIndicesOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_float_tensor::(&self.desc.tensor); + impl Operation for MinDimWithIndicesOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim); - handles.register_float_tensor(&self.desc.out.id, output); - handles.register_int_tensor(&self.desc.out_indices.id, indices); + handles.register_float_tensor::(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } @@ -1817,8 +1932,8 @@ impl FloatTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape.clone(), 
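`max_dim_with_indices` (and `min_dim_with_indices` below) is the first op here to register two placeholders at once, with different dtypes: the values as `B::FloatElem::dtype()` and the indices as `B::IntElem::dtype()`. Its semantics reduced to the 1-D case, as a reference only:

    /// 1-D reference: returns (max value, index of first max).
    /// Assumes a non-empty input, as the real op does.
    fn max_with_index(data: &[f32]) -> (f32, usize) {
        let mut best = (data[0], 0usize);
        for (i, &v) in data.iter().enumerate().skip(1) {
            if v > best.0 {
                best = (v, i);
            }
        }
        best
    }

    fn main() {
        assert_eq!(max_with_index(&[1.0, 4.0, 2.0, 4.0]), (4.0, 1));
    }
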
B::FloatElem::dtype()); + let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReduceDimWithIndicesDescription { tensor: tensor.into_description(), @@ -1831,7 +1946,7 @@ impl FloatTensorOps for Fusion { OperationDescription::NumericFloat(NumericOperationDescription::MinDimWithIndices( desc.clone(), )), - MinDimWithIndicesOps::::new(desc), + MinDimWithIndicesOps::::new(desc), ); (out, out_indices) @@ -1845,9 +1960,10 @@ impl FloatTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::FloatElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -1857,7 +1973,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericFloat(NumericOperationDescription::Powf(desc.clone())), - PowOps::::new(desc), + PowOps::::new(desc), ); out @@ -1868,16 +1984,17 @@ impl FloatTensorOps for Fusion { axes: [usize; D], ) -> FloatTensor { #[derive(new)] - struct PermuteDimsOps { + struct PermuteDimsOps { desc: PermuteOperationDescription, + _b: PhantomData, } - impl Operation for PermuteDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + impl Operation for PermuteDimsOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); let output = B::float_permute(input, axes); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } @@ -1886,7 +2003,9 @@ impl FloatTensorOps for Fusion { // Change the shape of the tensor to match the new axes let shape = axes.into_iter().map(|x| tensor.shape[x]).collect(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = PermuteOperationDescription { input: tensor.into_description(), @@ -1897,7 +2016,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Permute(desc.clone())), - PermuteDimsOps::::new(desc), + PermuteDimsOps::::new(desc), ); out @@ -1908,23 +2027,28 @@ impl FloatTensorOps for Fusion { shape: Shape, ) -> FloatTensor { #[derive(new)] - struct ExpandOps { + struct ExpandOps { desc: ExpandOperationDescription, + _b: PhantomData, } - impl Operation for ExpandOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + impl Operation + for ExpandOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); let output = B::float_expand(input, shape.into()); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(shape.dims.into()); + let out = tensor + .client + .tensor_uninitialized(shape.dims.into(), B::FloatElem::dtype()); let desc = ExpandOperationDescription { input: tensor.into_description(), @@ -1935,7 +2059,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], 
OperationDescription::BaseFloat(BaseOperationDescription::Expand(desc.clone())), - ExpandOps::::new(desc), + ExpandOps::::new(desc), ); out @@ -1946,20 +2070,23 @@ impl FloatTensorOps for Fusion { axes: &[usize], ) -> FloatTensor { #[derive(new)] - struct FlipOps { + struct FlipOps { desc: FlipOperationDescription, + _b: PhantomData, } - impl Operation for FlipOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + impl Operation for FlipOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_flip(input, &self.desc.axes); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = FlipOperationDescription { input: tensor.into_description(), @@ -1970,7 +2097,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Flip(desc.clone())), - FlipOps::::new(desc), + FlipOps::::new(desc), ); out diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index c990d3f25c..70b2287cec 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -9,20 +9,11 @@ use crate::{ }; use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, - repr::{ - self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, - ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - GatherOperationDescription, HandleContainer, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatOperationDescription, ReshapeDescription, ScalarOperationDescription, - ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, - TensorDescription, UnaryOperationDescription, - }, - Data, Device, Distribution, ElementConversion, Reader, Shape, + repr::{self, *}, + DType, Data, Device, Distribution, Element, ElementConversion, Reader, Shape, }; use core::ops::Range; +use std::marker::PhantomData; impl IntTensorOps for Fusion { fn int_empty(shape: Shape, device: &Device) -> IntTensor { @@ -30,7 +21,12 @@ impl IntTensorOps for Fusion { let tensor = B::int_empty(shape.clone(), device); let stream = StreamId::current(); - client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into(), stream) + client.register_tensor( + B::int_tensor_handle(tensor), + shape.dims.into(), + stream, + B::IntElem::dtype(), + ) } fn int_shape(tensor: &IntTensor) -> Shape { @@ -38,7 +34,7 @@ impl IntTensorOps for Fusion { } fn int_into_data(tensor: IntTensor) -> Reader, D>> { - tensor.int_into_data() + tensor.int_into_data::() } fn int_from_data( @@ -50,7 +46,12 @@ impl IntTensorOps for Fusion { let shape = B::int_shape(&tensor); let stream = StreamId::current(); - client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into(), stream) + client.register_tensor( + B::int_tensor_handle(tensor), + shape.dims.into(), + stream, + B::IntElem::dtype(), + ) } fn 
int_device(tensor: &IntTensor) -> Device { @@ -72,9 +73,11 @@ impl IntTensorOps for Fusion { let client_target = get_client::(&device_target); let client_original = tensor.client.clone(); - client_original - .clone() - .change_client_int::(tensor.into_description(), client_target, id) + client_original.clone().change_client_int::( + tensor.into_description(), + client_target, + id, + ) } fn int_reshape( @@ -82,21 +85,26 @@ impl IntTensorOps for Fusion { shape: Shape, ) -> IntTensor { #[derive(new)] - struct ReshapeDimsOps { + struct ReshapeDimsOps { desc: ReshapeDescription, + _b: PhantomData, } - impl Operation for ReshapeDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + impl Operation + for ReshapeDimsOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_reshape::(input, Shape::from(&self.desc.out.shape)); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; let shape: Vec = shape.dims.into(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReshapeDescription { input: tensor.into_description(), @@ -105,7 +113,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Reshape(desc.clone())), - ReshapeDimsOps::::new(desc), + ReshapeDimsOps::::new(desc), ); out @@ -116,18 +124,21 @@ impl IntTensorOps for Fusion { ranges: [Range; D2], ) -> IntTensor { #[derive(new)] - struct SliceOps { + struct SliceOps { desc: SliceOperationDescription, + _b: PhantomData, } - impl Operation for SliceOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); + impl Operation + for SliceOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_slice::(tensor, self.desc.ranges.clone().try_into().unwrap()); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -138,7 +149,9 @@ impl IntTensorOps for Fusion { shape.push(tensor.shape[i]); } - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SliceOperationDescription { tensor: tensor.into_description(), @@ -148,7 +161,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Slice(desc.clone())), - SliceOps::::new(desc), + SliceOps::::new(desc), ); out @@ -160,14 +173,17 @@ impl IntTensorOps for Fusion { value: IntTensor, ) -> IntTensor { #[derive(new)] - struct SliceAssignOps { + struct SliceAssignOps { desc: SliceAssignOperationDescription, + _b: PhantomData, } - impl Operation for SliceAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let value = handles.get_int_tensor::(&self.desc.value); + impl Operation + for SliceAssignOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let value = handles.get_int_tensor::(&self.desc.value); let output = B::int_slice_assign::( tensor, @@ -175,14 +191,16 @@ impl 
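The int ops mirror the float side one-for-one, registering handles and placeholders with `B::IntElem::dtype()`. One invariant worth spelling out for `int_reshape` (and `float_reshape` above): only the grouping of elements changes, the element count must be preserved. A sketch of that check; the helper is ours, on the assumption that burn enforces the same invariant elsewhere:

    /// Reshape is valid only if both shapes cover the same element count.
    fn reshape_checked(from: &[usize], to: &[usize]) -> Result<Vec<usize>, String> {
        let (a, b): (usize, usize) = (from.iter().product(), to.iter().product());
        if a == b {
            Ok(to.to_vec())
        } else {
            Err(format!("cannot reshape {a} elements into {b}"))
        }
    }

    fn main() {
        assert!(reshape_checked(&[2, 6], &[3, 4]).is_ok());
        assert!(reshape_checked(&[2, 6], &[5, 2]).is_err());
    }
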
IntTensorOps for Fusion { value, ); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream_1 = tensor.stream; let stream_2 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SliceAssignOperationDescription { tensor: tensor.into_description(), ranges: ranges.into(), @@ -192,7 +210,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseInt(BaseOperationDescription::SliceAssign(desc.clone())), - SliceAssignOps::::new(desc), + SliceAssignOps::::new(desc), ); out @@ -204,19 +222,20 @@ impl IntTensorOps for Fusion { value: IntTensor, ) -> IntTensor { #[derive(new)] - struct MaskWhereOps { + struct MaskWhereOps { desc: MaskWhereOperationDescription, + _b: PhantomData, } - impl Operation for MaskWhereOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let value = handles.get_int_tensor(&self.desc.value); - let mask = handles.get_bool_tensor(&self.desc.mask); + impl Operation for MaskWhereOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let value = handles.get_int_tensor::(&self.desc.value); + let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::int_mask_where(tensor, mask, value); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -224,7 +243,9 @@ impl IntTensorOps for Fusion { let stream_2 = mask.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = MaskWhereOperationDescription { tensor: tensor.into_description(), @@ -235,7 +256,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2, stream_3], OperationDescription::NumericInt(NumericOperationDescription::MaskWhere(desc.clone())), - MaskWhereOps::::new(desc), + MaskWhereOps::::new(desc), ); out @@ -247,25 +268,28 @@ impl IntTensorOps for Fusion { value: IntElem, ) -> IntTensor { #[derive(new)] - struct MaskFillOps { + struct MaskFillOps { desc: MaskFillOperationDescription, + _b: PhantomData, } - impl Operation for MaskFillOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let mask = handles.get_bool_tensor(&self.desc.mask); + impl Operation for MaskFillOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::int_mask_fill(tensor, mask, self.desc.value.elem()); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream_1 = tensor.stream; let stream_2 = mask.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = MaskFillOperationDescription { tensor: tensor.into_description(), value: value.elem(), @@ -275,7 +299,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], 
OperationDescription::NumericInt(NumericOperationDescription::MaskFill(desc.clone())), - MaskFillOps::::new(desc), + MaskFillOps::::new(desc), ); out @@ -287,24 +311,27 @@ impl IntTensorOps for Fusion { indices: IntTensor, ) -> IntTensor { #[derive(new)] - struct GatherOps { + struct GatherOps { desc: GatherOperationDescription, + _b: PhantomData, } - impl Operation for GatherOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); + impl Operation for GatherOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::int_gather(self.desc.dim, tensor, indices); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream_1 = tensor.stream; let stream_2 = indices.stream; let shape: Vec = indices.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = GatherOperationDescription { tensor: tensor.into_description(), dim, @@ -314,7 +341,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::Gather(desc.clone())), - GatherOps::::new(desc), + GatherOps::::new(desc), ); out @@ -327,19 +354,20 @@ impl IntTensorOps for Fusion { value: IntTensor, ) -> IntTensor { #[derive(new)] - struct ScatterOps { + struct ScatterOps { desc: ScatterOperationDescription, + _b: PhantomData, } - impl Operation for ScatterOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); - let value = handles.get_int_tensor(&self.desc.value); + impl Operation for ScatterOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + let value = handles.get_int_tensor::(&self.desc.value); let output = B::int_scatter(self.desc.dim, tensor, indices, value); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -347,7 +375,9 @@ impl IntTensorOps for Fusion { let stream_2 = indices.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScatterOperationDescription { tensor: tensor.into_description(), dim, @@ -358,7 +388,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2, stream_3], OperationDescription::NumericInt(NumericOperationDescription::Scatter(desc.clone())), - ScatterOps::::new(desc), + ScatterOps::::new(desc), ); out @@ -370,18 +400,19 @@ impl IntTensorOps for Fusion { indices: IntTensor, ) -> IntTensor { #[derive(new)] - struct SelectOps { + struct SelectOps { desc: SelectOperationDescription, + _b: PhantomData, } - impl Operation for SelectOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); + impl Operation for SelectOps { + fn execute(self: Box, handles: &mut 
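Throughout both files, each operand contributes its `stream` id to the `register` call (`vec![stream_1, stream_2, stream_3]` for three-operand ops like scatter), so the executor knows exactly which producing streams an op depends on. A toy of that bookkeeping; the deduplication here is an assumption for illustration, the real client receives the ids verbatim and may handle duplicates differently:

    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct StreamId(u64);

    /// Collect the distinct streams an op depends on, preserving order.
    fn dependency_streams(operands: &[StreamId]) -> Vec<StreamId> {
        let mut seen = Vec::new();
        for &s in operands {
            if !seen.contains(&s) {
                seen.push(s);
            }
        }
        seen
    }

    fn main() {
        let (s1, s2) = (StreamId(1), StreamId(2));
        // scatter(tensor, indices, value): three operands, two distinct streams.
        assert_eq!(dependency_streams(&[s1, s2, s1]), vec![s1, s2]);
    }
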
HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::int_select(tensor, self.desc.dim, indices); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -389,7 +420,9 @@ impl IntTensorOps for Fusion { let stream_2 = indices.stream; let mut shape: Vec = tensor.shape.clone(); shape[dim] = indices.shape[0]; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SelectOperationDescription { tensor: tensor.into_description(), dim, @@ -399,7 +432,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::Select(desc.clone())), - SelectOps::::new(desc), + SelectOps::::new(desc), ); out @@ -412,19 +445,20 @@ impl IntTensorOps for Fusion { value: IntTensor, ) -> IntTensor { #[derive(new)] - struct SelectAssignOps { + struct SelectAssignOps { desc: SelectAssignOperationDescription, + _b: PhantomData, } - impl Operation for SelectAssignOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); - let indices = handles.get_int_tensor(&self.desc.indices); - let value = handles.get_int_tensor(&self.desc.value); + impl Operation for SelectAssignOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + let value = handles.get_int_tensor::(&self.desc.value); let output = B::int_select_assign(tensor, self.desc.dim, indices, value); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -432,7 +466,9 @@ impl IntTensorOps for Fusion { let stream_2 = indices.stream; let stream_3 = value.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SelectAssignOperationDescription { tensor: tensor.into_description(), dim, @@ -445,7 +481,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::SelectAssign( desc.clone(), )), - SelectAssignOps::::new(desc), + SelectAssignOps::::new(desc), ); out @@ -453,22 +489,23 @@ impl IntTensorOps for Fusion { fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { #[derive(new)] - struct CatOps { + struct CatOps { desc: CatOperationDescription, + _b: PhantomData, } - impl Operation for CatOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for CatOps { + fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc .tensors .iter() - .map(|tensor| handles.get_int_tensor(tensor)) + .map(|tensor| handles.get_int_tensor::(tensor)) .collect(); let output = B::int_cat::(tensors, self.desc.dim); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -483,7 +520,7 @@ impl IntTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(), @@ -493,7 +530,7 @@ impl IntTensorOps for 
Fusion { client.register( streams, OperationDescription::BaseInt(BaseOperationDescription::Cat(desc.clone())), - CatOps::::new(desc), + CatOps::::new(desc), ); out @@ -509,7 +546,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -519,7 +556,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::BaseInt(BaseOperationDescription::Equal(desc.clone())), - EqualOps::::new(desc), + EqualOps::::new(desc), ); out @@ -532,7 +569,9 @@ impl IntTensorOps for Fusion { scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -542,7 +581,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::EqualElem(desc.clone())), - EqualElemOps::::new(desc), + EqualElemOps::::new(desc), ); out @@ -558,7 +597,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -568,7 +607,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::Greater(desc.clone())), - GreaterOps::::new(desc), + GreaterOps::::new(desc), ); out @@ -581,7 +620,9 @@ impl IntTensorOps for Fusion { scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -593,7 +634,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::GreaterElem( desc.clone(), )), - GreaterElemOps::::new(desc), + GreaterElemOps::::new(desc), ); out @@ -609,7 +650,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -621,7 +662,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::GreaterEqual( desc.clone(), )), - GreaterEqualOps::::new(desc), + GreaterEqualOps::::new(desc), ); out @@ -634,7 +675,9 @@ impl IntTensorOps for Fusion { scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -646,7 +689,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::GreaterEqualElem( desc.clone(), )), - GreaterEqualElemOps::::new(desc), + GreaterEqualElemOps::::new(desc), ); out @@ -662,7 
+705,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -672,7 +715,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::Lower(desc.clone())), - LowerOps::::new(desc), + LowerOps::::new(desc), ); out @@ -685,7 +728,9 @@ impl IntTensorOps for Fusion { scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -695,7 +740,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::LowerElem(desc.clone())), - LowerElemOps::::new(desc), + LowerElemOps::::new(desc), ); out @@ -711,7 +756,7 @@ impl IntTensorOps for Fusion { let stream_2 = rhs.stream; let out = lhs .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -721,7 +766,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], OperationDescription::NumericInt(NumericOperationDescription::LowerEqual(desc.clone())), - LowerEqualOps::::new(desc), + LowerEqualOps::::new(desc), ); out @@ -734,7 +779,9 @@ impl IntTensorOps for Fusion { scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), DType::Bool); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -746,7 +793,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::LowerEqualElem( desc.clone(), )), - LowerEqualElemOps::::new(desc), + LowerEqualElemOps::::new(desc), ); out @@ -760,9 +807,10 @@ impl IntTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -772,7 +820,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], repr::OperationDescription::NumericInt(NumericOperationDescription::Add(desc.clone())), - AddOps::::new(desc), + AddOps::::new(desc), ); out @@ -785,7 +833,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(AddOps, B::int_add_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -797,7 +847,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::AddScalar( desc.clone(), )), - AddOps::::new(desc), + AddOps::::new(desc), ); out @@ -811,9 +861,10 @@ impl IntTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; 
- let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -823,7 +874,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], repr::OperationDescription::NumericInt(NumericOperationDescription::Sub(desc.clone())), - SubOps::::new(desc), + SubOps::::new(desc), ); out @@ -836,7 +887,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(SubOps, B::int_sub_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -848,7 +901,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::SubScalar( desc.clone(), )), - SubOps::::new(desc), + SubOps::::new(desc), ); out @@ -862,9 +915,10 @@ impl IntTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -874,7 +928,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], repr::OperationDescription::NumericInt(NumericOperationDescription::Mul(desc.clone())), - MulOps::::new(desc), + MulOps::::new(desc), ); out @@ -887,7 +941,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(MulOps, B::int_mul_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -899,7 +955,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::MulScalar( desc.clone(), )), - MulOps::::new(desc), + MulOps::::new(desc), ); out @@ -913,9 +969,10 @@ impl IntTensorOps for Fusion { let stream_1 = lhs.stream; let stream_2 = rhs.stream; - let out = lhs - .client - .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); let desc = BinaryOperationDescription { lhs: lhs.into_description(), @@ -925,7 +982,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], repr::OperationDescription::NumericInt(NumericOperationDescription::Div(desc.clone())), - DivOps::::new(desc), + DivOps::::new(desc), ); out @@ -938,7 +995,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(DivOps, B::int_div_scalar); let stream = lhs.stream; - let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -950,7 +1009,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::DivScalar( desc.clone(), )), - DivOps::::new(desc), + DivOps::::new(desc), ); out @@ -963,7 +1022,9 @@ impl IntTensorOps for Fusion { scalar_int_ops!(ModOps, B::int_remainder_scalar); let stream = lhs.stream; - let out = 
lhs.client.tensor_uninitialized(lhs.shape.clone()); + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -975,7 +1036,7 @@ impl IntTensorOps for Fusion { repr::OperationDescription::NumericInt(NumericOperationDescription::RemScalar( desc.clone(), )), - ModOps::::new(desc), + ModOps::::new(desc), ); out @@ -983,27 +1044,28 @@ impl IntTensorOps for Fusion { fn int_zeros(shape: Shape, device: &Device) -> IntTensor { #[derive(new)] - struct ZerosOps { + struct ZerosOps { desc: TensorDescription, + device: Device, } - impl Operation for ZerosOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for ZerosOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.shape.clone()); - let output = B::int_zeros::(shape, &handles.device); - handles.register_int_tensor(&self.desc.id, output); + let output = B::int_zeros::(shape, &self.device); + handles.register_int_tensor::(&self.desc.id, output); } } let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = out.to_description_out(); client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Zeros(desc.clone())), - ZerosOps::::new(desc), + ZerosOps::::new(desc, device.clone()), ); out @@ -1011,38 +1073,41 @@ impl IntTensorOps for Fusion { fn int_ones(shape: Shape, device: &Device) -> IntTensor { #[derive(new)] - struct OnesOps { + struct OnesOps { desc: TensorDescription, + device: Device, } - impl Operation for OnesOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for OnesOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.shape.clone()); - let output = B::int_ones::(shape, &handles.device); - handles.register_int_tensor(&self.desc.id, output); + let output = B::int_ones::(shape, &self.device); + handles.register_int_tensor::(&self.desc.id, output); } } let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = out.to_description_out(); client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Ones(desc.clone())), - OnesOps::::new(desc), + OnesOps::::new(desc, device.clone()), ); out } fn int_sum(tensor: IntTensor) -> IntTensor { - unary_int_ops!(SumOps, B::int_sum); + unary_int_ops!(SumOps, B::int_sum, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1051,7 +1116,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Sum(desc.clone())), - SumOps::::new(desc), + SumOps::::new(desc), ); out @@ -1063,7 +1128,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: 
tensor.into_description(), @@ -1073,17 +1140,19 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::SumDim(desc.clone())), - SumDimOps::::new(desc), + SumDimOps::::new(desc), ); out } fn int_prod(tensor: IntTensor) -> IntTensor { - unary_int_ops!(ProdOps, B::int_prod); + unary_int_ops!(ProdOps, B::int_prod, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1092,7 +1161,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Prod(desc.clone())), - ProdOps::::new(desc), + ProdOps::::new(desc), ); out @@ -1104,7 +1173,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1114,17 +1185,19 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::ProdDim(desc.clone())), - ProdDimOps::::new(desc), + ProdDimOps::::new(desc), ); out } fn int_mean(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MeanOps, B::int_mean); + unary_int_ops!(MeanOps, B::int_mean, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1133,7 +1206,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Mean(desc.clone())), - MeanOps::::new(desc), + MeanOps::::new(desc), ); out @@ -1145,7 +1218,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1155,7 +1230,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::MeanDim(desc.clone())), - MeanDimOps::::new(desc), + MeanDimOps::::new(desc), ); out @@ -1167,7 +1242,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1177,7 +1254,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::ArgMax(desc.clone())), - ArgMaxOps::::new(desc), + ArgMaxOps::::new(desc), ); out @@ -1189,7 +1266,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ 
-1199,7 +1278,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::ArgMin(desc.clone())), - ArgMinOps::::new(desc), + ArgMinOps::::new(desc), ); out @@ -1211,21 +1290,24 @@ impl IntTensorOps for Fusion { max: IntElem, ) -> IntTensor { #[derive(new)] - struct ClampOps { + struct ClampOps { desc: ClampOperationDescription, + _b: PhantomData, } - impl Operation for ClampOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.tensor); + impl Operation for ClampOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem()); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); let desc = ClampOperationDescription { tensor: tensor.into_description(), min: min.elem(), @@ -1235,7 +1317,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Clamp(desc.clone())), - ClampOps::::new(desc), + ClampOps::::new(desc), ); out @@ -1245,7 +1327,9 @@ impl IntTensorOps for Fusion { unary_int_ops!(AbsOps, B::int_abs); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1254,7 +1338,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Abs(desc.clone())), - AbsOps::::new(desc), + AbsOps::::new(desc), ); out @@ -1262,20 +1346,23 @@ impl IntTensorOps for Fusion { fn int_into_float(tensor: IntTensor) -> FloatTensor { #[derive(new)] - struct IntoFloatOps { + struct IntoFloatOps { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for IntoFloatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + impl Operation for IntoFloatOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_into_float(input); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), @@ -1283,7 +1370,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::Int(repr::IntOperationDescription::IntoFloat(desc.clone())), - IntoFloatOps::::new(desc), + IntoFloatOps::::new(desc), ); out @@ -1295,15 +1382,16 @@ impl IntTensorOps for Fusion { dim2: usize, ) -> IntTensor { #[derive(new)] - struct SwapDimsOps { + struct SwapDimsOps { desc: SwapDimsDescription, + _b: PhantomData, } - impl Operation for SwapDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = 
handles.get_int_tensor::(&self.desc.input); + impl Operation for SwapDimsOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -1312,7 +1400,9 @@ impl IntTensorOps for Fusion { shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = SwapDimsDescription { input: tensor.into_description(), @@ -1323,17 +1413,19 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::SwapDims(desc.clone())), - SwapDimsOps::::new(desc), + SwapDimsOps::::new(desc), ); out } fn int_max(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MaxOps, B::int_max); + unary_int_ops!(MaxOps, B::int_max, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1342,7 +1434,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Max(desc.clone())), - MaxOps::::new(desc), + MaxOps::::new(desc), ); out @@ -1354,7 +1446,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1364,7 +1458,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::MaxDim(desc.clone())), - MaxDimOps::::new(desc), + MaxDimOps::::new(desc), ); out @@ -1375,17 +1469,18 @@ impl IntTensorOps for Fusion { dim: usize, ) -> (IntTensor, IntTensor) { #[derive(new)] - struct MaxDimWithIndicesOps { + struct MaxDimWithIndicesOps { desc: ReduceDimWithIndicesDescription, + _b: PhantomData, } - impl Operation for MaxDimWithIndicesOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); + impl Operation for MaxDimWithIndicesOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim); - handles.register_int_tensor(&self.desc.out.id, output); - handles.register_int_tensor(&self.desc.out_indices.id, indices); + handles.register_int_tensor::(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } @@ -1393,8 +1488,8 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape.clone(), B::IntElem::dtype()); + let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReduceDimWithIndicesDescription { tensor: tensor.into_description(), dim, @@ -1406,17 +1501,19 @@ impl IntTensorOps for 
Fusion { OperationDescription::NumericInt(NumericOperationDescription::MaxDimWithIndices( desc.clone(), )), - MaxDimWithIndicesOps::::new(desc), + MaxDimWithIndicesOps::::new(desc), ); (out, out_indices) } fn int_min(tensor: IntTensor) -> IntTensor { - unary_int_ops!(MinOps, B::int_min); + unary_int_ops!(MinOps, B::int_min, reduce); let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(vec![1]); + let out = tensor + .client + .tensor_uninitialized(vec![1], B::IntElem::dtype()); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1425,7 +1522,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::Min(desc.clone())), - MinOps::::new(desc), + MinOps::::new(desc), ); out @@ -1437,7 +1534,9 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1447,7 +1546,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::MinDim(desc.clone())), - MinDimOps::::new(desc), + MinDimOps::::new(desc), ); out @@ -1458,17 +1557,18 @@ impl IntTensorOps for Fusion { dim: usize, ) -> (IntTensor, IntTensor) { #[derive(new)] - struct MinDimWithIndicesOps { + struct MinDimWithIndicesOps { desc: ReduceDimWithIndicesDescription, + _b: PhantomData, } - impl Operation for MinDimWithIndicesOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); + impl Operation for MinDimWithIndicesOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim); - handles.register_int_tensor(&self.desc.out.id, output); - handles.register_int_tensor(&self.desc.out_indices.id, indices); + handles.register_int_tensor::(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } @@ -1476,8 +1576,8 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.tensor_uninitialized(shape.clone()); - let out_indices = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape.clone(), B::IntElem::dtype()); + let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = ReduceDimWithIndicesDescription { tensor: tensor.into_description(), dim, @@ -1489,7 +1589,7 @@ impl IntTensorOps for Fusion { OperationDescription::NumericInt(NumericOperationDescription::MinDimWithIndices( desc.clone(), )), - MinDimWithIndicesOps::::new(desc), + MinDimWithIndicesOps::::new(desc), ); (out, out_indices) @@ -1501,23 +1601,24 @@ impl IntTensorOps for Fusion { device: &Device, ) -> IntTensor { #[derive(new)] - struct IntRandomOps { + struct IntRandomOps { desc: RandomOperationDescription, + device: Device, } - impl Operation for IntRandomOps { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for IntRandomOps { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.out.shape.clone()); let output: B::IntTensorPrimitive = - B::int_random(shape, self.desc.distribution, 
&handles.device); - handles.register_int_tensor(&self.desc.out.id, output); + B::int_random(shape, self.desc.distribution, &self.device); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream = StreamId::current(); let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(shape); + let out = client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = RandomOperationDescription { out: out.to_description_out(), @@ -1526,7 +1627,7 @@ impl IntTensorOps for Fusion { client.register( vec![stream], OperationDescription::NumericInt(NumericOperationDescription::IntRandom(desc.clone())), - IntRandomOps::::new(desc), + IntRandomOps::::new(desc, device.clone()), ); out @@ -1537,16 +1638,17 @@ impl IntTensorOps for Fusion { axes: [usize; D], ) -> IntTensor { #[derive(new)] - struct PermuteDimsOps { + struct PermuteDimsOps { desc: PermuteOperationDescription, + _b: PhantomData, } - impl Operation for PermuteDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + impl Operation for PermuteDimsOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); let output = B::int_permute(input, axes); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } @@ -1555,7 +1657,9 @@ impl IntTensorOps for Fusion { // Change the shape of the tensor to match the new axes let shape = axes.into_iter().map(|x| tensor.shape[x]).collect(); - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = PermuteOperationDescription { input: tensor.into_description(), @@ -1566,7 +1670,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Permute(desc.clone())), - PermuteDimsOps::::new(desc), + PermuteDimsOps::::new(desc), ); out @@ -1577,22 +1681,27 @@ impl IntTensorOps for Fusion { shape: Shape, ) -> IntTensor { #[derive(new)] - struct ExpandOps { + struct ExpandOps { desc: ExpandOperationDescription, + _b: PhantomData, } - impl Operation for ExpandOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_bool_tensor::(&self.desc.input); + impl Operation + for ExpandOps + { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_bool_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); let output = B::bool_expand(input, shape.into()); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(shape.dims.into()); + let out = tensor + .client + .tensor_uninitialized(shape.dims.into(), B::IntElem::dtype()); let desc = ExpandOperationDescription { input: tensor.into_description(), @@ -1603,7 +1712,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Expand(desc.clone())), - ExpandOps::::new(desc), + ExpandOps::::new(desc), ); out @@ -1611,22 +1720,25 @@ impl IntTensorOps for Fusion { fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { #[derive(new)] - struct FlipDimsOps { + struct FlipDimsOps { desc: 
FlipOperationDescription, + _b: PhantomData, } - impl Operation for FlipDimsOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + impl Operation for FlipDimsOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let axes = &self.desc.axes; let output = B::int_flip(input, axes); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); let desc = FlipOperationDescription { input: tensor.into_description(), @@ -1637,7 +1749,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Flip(desc.clone())), - FlipDimsOps::::new(desc), + FlipDimsOps::::new(desc), ); out @@ -1649,24 +1761,27 @@ impl IntTensorOps for Fusion { times: usize, ) -> IntTensor { #[derive(new)] - struct RepeatOps { + struct RepeatOps { desc: RepeatOperationDescription, + _b: PhantomData, } - impl Operation for RepeatOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_int_tensor::(&self.desc.tensor); + impl Operation for RepeatOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_repeat::(tensor, self.desc.dim, self.desc.times); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } let stream = tensor.stream; let mut shape = tensor.shape.clone(); shape[dim] *= times; - let out = tensor.client.tensor_uninitialized(shape); + let out = tensor + .client + .tensor_uninitialized(shape, B::IntElem::dtype()); let desc = RepeatOperationDescription { tensor: tensor.into_description(), @@ -1677,7 +1792,7 @@ impl IntTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())), - RepeatOps::::new(desc), + RepeatOps::::new(desc), ); out diff --git a/crates/burn-fusion/src/ops/module.rs b/crates/burn-fusion/src/ops/module.rs index a2543559c8..cb0003d5b2 100644 --- a/crates/burn-fusion/src/ops/module.rs +++ b/crates/burn-fusion/src/ops/module.rs @@ -9,28 +9,21 @@ use burn_tensor::{ MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }, - repr::{ - AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription, - AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription, - AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription, - AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, - ConvTranspose2dDescription, HandleContainer, InterpolateBackwardDescription, - InterpolateDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, - MaxPool1dWithIndicesDescription, MaxPool2dDescription, - MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, - ModuleOperationDescription, OperationDescription, - }, + repr::*, + Element, }; +use std::marker::PhantomData; macro_rules! 
make_ops { ($name:ident, $desc:ty, $fn:expr) => { #[derive(new)] - struct $name { + struct $name { desc: $desc, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { #[allow(clippy::redundant_closure_call)] $fn(self.desc, handles) } @@ -48,15 +41,15 @@ impl ModuleOps> for Fusion { make_ops!( Conv1dOps, Conv1dDescription, - |desc: Conv1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&desc.x); - let weight = handles.get_float_tensor(&desc.weight); + |desc: Conv1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&desc.x); + let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() - .map(|bias| handles.get_float_tensor(bias)); + .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv1d(x, weight, bias, desc.options.into()); - handles.register_float_tensor(&desc.out.id, output); + handles.register_float_tensor::(&desc.out.id, output); } ); @@ -72,7 +65,7 @@ impl ModuleOps> for Fusion { let stream_2 = weight.stream; let stream_3 = bias.as_ref().map(|b| b.stream); let shape = vec![x.shape[0], weight.shape[0], size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let description = Conv1dDescription { x: x.into_description(), @@ -89,7 +82,7 @@ impl ModuleOps> for Fusion { out.client.clone().register( streams, OperationDescription::Module(ModuleOperationDescription::Conv1d(description.clone())), - Conv1dOps::new(description), + Conv1dOps::::new(description), ); out @@ -104,17 +97,17 @@ impl ModuleOps> for Fusion { make_ops!( Conv2dOps, Conv2dDescription, - |args: Conv2dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); + |args: Conv2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() - .map(|bias| handles.get_float_tensor(bias)); + .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv2d(x, weight, bias, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -137,7 +130,7 @@ impl ModuleOps> for Fusion { let stream_2 = weight.stream; let stream_3 = bias.as_ref().map(|b| b.stream); let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = Conv2dDescription { x: x.into_description(), @@ -154,7 +147,7 @@ impl ModuleOps> for Fusion { out.client.register( streams, OperationDescription::Module(ModuleOperationDescription::Conv2d(desc.clone())), - Conv2dOps::new(desc), + Conv2dOps::::new(desc), ); out @@ -169,17 +162,17 @@ impl ModuleOps> for Fusion { make_ops!( ConvTranspose1dOps, ConvTranspose1dDescription, - |args: ConvTranspose1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); + |args: ConvTranspose1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() - .map(|bias| handles.get_float_tensor(bias)); + 
.map(|bias| handles.get_float_tensor::(bias)); let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -196,7 +189,7 @@ impl ModuleOps> for Fusion { let stream_2 = weight.stream; let stream_3 = bias.as_ref().map(|b| b.stream); let shape = vec![x.shape[0], weight.shape[1] * options.groups, size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ConvTranspose1dDescription { x: x.into_description(), @@ -213,7 +206,7 @@ impl ModuleOps> for Fusion { out.client.register( streams, OperationDescription::Module(ModuleOperationDescription::ConvTranspose1d(desc.clone())), - ConvTranspose1dOps::new(desc), + ConvTranspose1dOps::::new(desc), ); out @@ -228,17 +221,17 @@ impl ModuleOps> for Fusion { make_ops!( ConvTranspose2dOps, ConvTranspose2dDescription, - |args: ConvTranspose2dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let weight = handles.get_float_tensor(&args.weight); + |args: ConvTranspose2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() - .map(|bias| handles.get_float_tensor(bias)); + .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -263,7 +256,7 @@ impl ModuleOps> for Fusion { let stream_2 = weight.stream; let stream_3 = bias.as_ref().map(|b| b.stream); let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = ConvTranspose2dDescription { x: x.into_description(), @@ -280,7 +273,7 @@ impl ModuleOps> for Fusion { out.client.register( streams, OperationDescription::Module(ModuleOperationDescription::ConvTranspose2d(desc.clone())), - ConvTranspose2dOps::new(desc), + ConvTranspose2dOps::::new(desc), ); out @@ -296,8 +289,8 @@ impl ModuleOps> for Fusion { make_ops!( AvgPool1dOps, AvgPool1dDescription, - |args: AvgPool1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: AvgPool1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::avg_pool1d( x, args.kernel_size, @@ -306,14 +299,14 @@ impl ModuleOps> for Fusion { args.count_include_pad, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); let stream = x.stream; let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]); let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = AvgPool1dDescription { x: x.into_description(), @@ -326,7 +319,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::AvgPool1d(desc.clone())), - AvgPool1dOps::new(desc), + AvgPool1dOps::::new(desc), ); out @@ -342,8 +335,8 @@ impl ModuleOps> for Fusion { make_ops!( AvgPool2dOps, AvgPool2dDescription, - |args: AvgPool2dDescription, handles: 
&mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: AvgPool2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::avg_pool2d( x, args.kernel_size, @@ -352,7 +345,7 @@ impl ModuleOps> for Fusion { args.count_include_pad, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -363,7 +356,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = AvgPool2dDescription { x: x.into_description(), @@ -376,7 +369,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::AvgPool2d(desc.clone())), - AvgPool2dOps::new(desc), + AvgPool2dOps::::new(desc), ); out @@ -393,9 +386,9 @@ impl ModuleOps> for Fusion { make_ops!( AvgPool1dBackwardOps, AvgPool1dBackwardDescription, - |args: AvgPool1dBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: AvgPool1dBackwardDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::avg_pool1d_backward( x, grad, @@ -405,13 +398,15 @@ impl ModuleOps> for Fusion { args.count_include_pad, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); let stream_1 = x.stream; let stream_2 = grad.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = AvgPool1dBackwardDescription { x: x.into_description(), @@ -427,7 +422,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AvgPool1dBackward( desc.clone(), )), - AvgPool1dBackwardOps::new(desc), + AvgPool1dBackwardOps::::new(desc), ); out @@ -444,9 +439,9 @@ impl ModuleOps> for Fusion { make_ops!( AvgPool2dBackwardOps, AvgPool2dBackwardDescription, - |args: AvgPool2dBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: AvgPool2dBackwardDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::avg_pool2d_backward( x, grad, @@ -456,13 +451,15 @@ impl ModuleOps> for Fusion { args.count_include_pad, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); let stream_1 = x.stream; let stream_2 = grad.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = AvgPool2dBackwardDescription { x: x.into_description(), @@ -478,7 +475,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AvgPool2dBackward( desc.clone(), )), - AvgPool2dBackwardOps::new(desc), + AvgPool2dBackwardOps::::new(desc), ); out @@ -494,8 +491,8 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool1dOps, MaxPool1dDescription, - |args: MaxPool1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: 
MaxPool1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::max_pool1d( x, args.kernel_size, @@ -504,7 +501,7 @@ impl ModuleOps> for Fusion { args.dilation, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -512,7 +509,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = MaxPool1dDescription { x: x.into_description(), @@ -525,7 +522,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::MaxPool1d(desc.clone())), - MaxPool1dOps::new(desc), + MaxPool1dOps::::new(desc), ); out @@ -541,8 +538,8 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool2dOps, MaxPool2dDescription, - |args: MaxPool2dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: MaxPool2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::max_pool2d( x, args.kernel_size, @@ -551,7 +548,7 @@ impl ModuleOps> for Fusion { args.dilation, ); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); @@ -572,7 +569,7 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = MaxPool2dDescription { x: x.into_description(), @@ -585,7 +582,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::MaxPool2d(desc.clone())), - MaxPool2dOps::new(desc), + MaxPool2dOps::::new(desc), ); out @@ -601,8 +598,8 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool1dWithIndicesOps, MaxPool1dWithIndicesDescription, - |args: MaxPool1dWithIndicesDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: MaxPool1dWithIndicesDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::max_pool1d_with_indices( x, args.kernel_size, @@ -611,16 +608,18 @@ impl ModuleOps> for Fusion { args.dilation, ); - handles.register_float_tensor(&args.out.id, output.output); - handles.register_int_tensor(&args.out_indices.id, output.indices); + handles.register_float_tensor::(&args.out.id, output.output); + handles.register_int_tensor::(&args.out_indices.id, output.indices); } ); let stream = x.stream; let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.tensor_uninitialized(shape.clone()); - let out_indices = x.client.tensor_uninitialized(shape); + let out = x + .client + .tensor_uninitialized(shape.clone(), B::FloatElem::dtype()); + let out_indices = x.client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = MaxPool1dWithIndicesDescription { x: x.into_description(), @@ -636,7 +635,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndices( desc.clone(), )), - MaxPool1dWithIndicesOps::new(desc), + MaxPool1dWithIndicesOps::::new(desc), ); MaxPool1dWithIndices::new(out, out_indices) @@ -652,8 +651,8 @@ impl ModuleOps> for Fusion { 
make_ops!( MaxPool2dWithIndicesOps, MaxPool2dWithIndicesDescription, - |args: MaxPool2dWithIndicesDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: MaxPool2dWithIndicesDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::max_pool2d_with_indices( x, args.kernel_size, @@ -662,8 +661,8 @@ impl ModuleOps> for Fusion { args.dilation, ); - handles.register_float_tensor(&args.out.id, output.output); - handles.register_int_tensor(&args.out_indices.id, output.indices); + handles.register_float_tensor::(&args.out.id, output.output); + handles.register_int_tensor::(&args.out_indices.id, output.indices); } ); @@ -684,8 +683,10 @@ impl ModuleOps> for Fusion { let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.tensor_uninitialized(shape.clone()); - let out_indices = x.client.tensor_uninitialized(shape); + let out = x + .client + .tensor_uninitialized(shape.clone(), B::FloatElem::dtype()); + let out_indices = x.client.tensor_uninitialized(shape, B::IntElem::dtype()); let desc = MaxPool2dWithIndicesDescription { x: x.into_description(), @@ -701,7 +702,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndices( desc.clone(), )), - MaxPool2dWithIndicesOps::new(desc), + MaxPool2dWithIndicesOps::::new(desc), ); MaxPool2dWithIndices::new(out, out_indices) @@ -719,10 +720,11 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool1dWithIndicesBackwardOps, MaxPool1dWithIndicesBackwardDescription, - |args: MaxPool1dWithIndicesBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let indices = handles.get_int_tensor(&args.indices); + |args: MaxPool1dWithIndicesBackwardDescription, + handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); + let indices = handles.get_int_tensor::(&args.indices); let output = B::max_pool1d_with_indices_backward( x, args.kernel_size, @@ -733,14 +735,16 @@ impl ModuleOps> for Fusion { indices, ); - handles.register_float_tensor(&args.out.id, output.x_grad); + handles.register_float_tensor::(&args.out.id, output.x_grad); } ); let stream_1 = x.stream; let stream_2 = output_grad.stream; let stream_3 = indices.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = MaxPool1dWithIndicesBackwardDescription { x: x.into_description(), @@ -757,7 +761,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndicesBackward( desc.clone(), )), - MaxPool1dWithIndicesBackwardOps::new(desc), + MaxPool1dWithIndicesBackwardOps::::new(desc), ); MaxPool1dBackward::new(out) @@ -775,10 +779,11 @@ impl ModuleOps> for Fusion { make_ops!( MaxPool2dWithIndicesBackwardOps, MaxPool2dWithIndicesBackwardDescription, - |args: MaxPool2dWithIndicesBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); - let indices = handles.get_int_tensor(&args.indices); + |args: MaxPool2dWithIndicesBackwardDescription, + handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); + let indices = handles.get_int_tensor::(&args.indices); let 
output = B::max_pool2d_with_indices_backward( x, args.kernel_size, @@ -789,14 +794,16 @@ impl ModuleOps> for Fusion { indices, ); - handles.register_float_tensor(&args.out.id, output.x_grad); + handles.register_float_tensor::(&args.out.id, output.x_grad); } ); let stream_1 = x.stream; let stream_2 = output_grad.stream; let stream_3 = indices.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = MaxPool2dWithIndicesBackwardDescription { x: x.into_description(), @@ -813,7 +820,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndicesBackward( desc.clone(), )), - MaxPool2dWithIndicesBackwardOps::new(desc), + MaxPool2dWithIndicesBackwardOps::::new(desc), ); MaxPool2dBackward::new(out) @@ -823,17 +830,17 @@ impl ModuleOps> for Fusion { make_ops!( AdaptiveAvgPool1dOps, AdaptiveAvgPool1dDescription, - |args: AdaptiveAvgPool1dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: AdaptiveAvgPool1dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::adaptive_avg_pool1d(x, args.output_size); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], output_size]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = AdaptiveAvgPool1dDescription { x: x.into_description(), @@ -845,7 +852,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1d( desc.clone(), )), - AdaptiveAvgPool1dOps::new(desc), + AdaptiveAvgPool1dOps::::new(desc), ); out @@ -858,17 +865,17 @@ impl ModuleOps> for Fusion { make_ops!( AdaptiveAvgPool2dOps, AdaptiveAvgPool2dDescription, - |args: AdaptiveAvgPool2dDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: AdaptiveAvgPool2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::adaptive_avg_pool2d(x, args.output_size); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = AdaptiveAvgPool2dDescription { x: x.into_description(), @@ -880,7 +887,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2d( desc.clone(), )), - AdaptiveAvgPool2dOps::new(desc), + AdaptiveAvgPool2dOps::::new(desc), ); out @@ -893,18 +900,21 @@ impl ModuleOps> for Fusion { make_ops!( AdaptiveAvgPool1dBackwardOps, AdaptiveAvgPool1dBackwardDescription, - |args: AdaptiveAvgPool1dBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: AdaptiveAvgPool1dBackwardDescription, + handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::adaptive_avg_pool1d_backward(x, grad); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, 
output); } ); let stream_1 = x.stream; let stream_2 = grad.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = AdaptiveAvgPool1dBackwardDescription { x: x.into_description(), grad: grad.into_description(), @@ -916,7 +926,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1dBackward( desc.clone(), )), - AdaptiveAvgPool1dBackwardOps::new(desc), + AdaptiveAvgPool1dBackwardOps::::new(desc), ); out @@ -929,18 +939,21 @@ impl ModuleOps> for Fusion { make_ops!( AdaptiveAvgPool2dBackwardOps, AdaptiveAvgPool2dBackwardDescription, - |args: AdaptiveAvgPool2dBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: AdaptiveAvgPool2dBackwardDescription, + handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::adaptive_avg_pool2d_backward(x, grad); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); let stream_1 = x.stream; let stream_2 = grad.stream; - let out = x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = AdaptiveAvgPool2dBackwardDescription { x: x.into_description(), @@ -952,7 +965,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2dBackward( desc.clone(), )), - AdaptiveAvgPool2dBackwardOps::new(desc), + AdaptiveAvgPool2dBackwardOps::::new(desc), ); out @@ -966,16 +979,16 @@ impl ModuleOps> for Fusion { make_ops!( InterpolateOps, InterpolateDescription, - |args: InterpolateDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); + |args: InterpolateDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); let output = B::interpolate(x, args.output_size, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); let stream = x.stream; let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; - let out = x.client.tensor_uninitialized(shape); + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); let desc = InterpolateDescription { x: x.into_description(), @@ -987,7 +1000,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], OperationDescription::Module(ModuleOperationDescription::Interpolate(desc.clone())), - InterpolateOps::new(desc), + InterpolateOps::::new(desc), ); out @@ -1002,19 +1015,21 @@ impl ModuleOps> for Fusion { make_ops!( InterpolateBackwardOps, InterpolateBackwardDescription, - |args: InterpolateBackwardDescription, handles: &mut HandleContainer| { - let x = handles.get_float_tensor(&args.x); - let grad = handles.get_float_tensor(&args.grad); + |args: InterpolateBackwardDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let grad = handles.get_float_tensor::(&args.grad); let output = B::interpolate_backward(x, grad, args.output_size, args.options.clone().into()); - handles.register_float_tensor(&args.out.id, output); + handles.register_float_tensor::(&args.out.id, output); } ); let stream_1 = x.stream; let stream_2 = grad.stream; - let out = 
x.client.tensor_uninitialized(x.shape.clone()); + let out = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); let desc = InterpolateBackwardDescription { x: x.into_description(), @@ -1028,7 +1043,7 @@ impl ModuleOps> for Fusion { OperationDescription::Module(ModuleOperationDescription::InterpolateBackward( desc.clone(), )), - InterpolateBackwardOps::new(desc), + InterpolateBackwardOps::::new(desc), ); out } diff --git a/crates/burn-fusion/src/ops/unary.rs b/crates/burn-fusion/src/ops/unary.rs index d6f0833c64..0120b77948 100644 --- a/crates/burn-fusion/src/ops/unary.rs +++ b/crates/burn-fusion/src/ops/unary.rs @@ -13,16 +13,17 @@ macro_rules! scalar_float_ops { $elem:ty ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } }; @@ -33,16 +34,17 @@ macro_rules! scalar_float_ops { noconvert ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); } } }; @@ -57,16 +59,17 @@ macro_rules! scalar_float2int_ops { $elem:ty ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs.clone()); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; @@ -80,16 +83,37 @@ macro_rules! 
unary_float_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_float_tensor::(&self.desc.input); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); let output = $ops(input); - handles.register_float_tensor(&self.desc.out.id, output); + handles.register_float_tensor::(&self.desc.out.id, output); + } + } + }; + ( + $name:ident, + $ops:expr, + reduce + ) => { + #[derive(new)] + struct $name { + desc: UnaryOperationDescription, + _b: PhantomData, + } + + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_float_tensor::(&self.desc.input); + let output = $ops(input); + + handles.register_float_tensor::(&self.desc.out.id, output); } } }; @@ -103,16 +127,37 @@ macro_rules! unary_int_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { + desc: UnaryOperationDescription, + _b: PhantomData, + } + + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); + let output = $ops(input); + + handles.register_int_tensor::(&self.desc.out.id, output); + } + } + }; + ( + $name:ident, + $ops:expr, + reduce + ) => { + #[derive(new)] + struct $name { desc: UnaryOperationDescription, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let input = handles.get_int_tensor::(&self.desc.input); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); let output = $ops(input); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; @@ -126,16 +171,17 @@ macro_rules! scalar_float_cmp_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_float_tensor::(&self.desc.lhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } }; @@ -149,16 +195,17 @@ macro_rules! scalar_int_cmp_ops { $ops:expr ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); - handles.register_bool_tensor(&self.desc.out.id, output); + handles.register_bool_tensor::(&self.desc.out.id, output); } } }; @@ -179,16 +226,17 @@ macro_rules! 
scalar_int_ops { $elem:ty ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; @@ -199,16 +247,17 @@ macro_rules! scalar_int_ops { noconvert ) => { #[derive(new)] - struct $name { + struct $name { desc: ScalarOperationDescription<$elem>, + _b: PhantomData, } - impl Operation for $name { - fn execute(self: Box, handles: &mut HandleContainer) { - let lhs = handles.get_int_tensor::(&self.desc.lhs); + impl Operation for $name { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); - handles.register_int_tensor(&self.desc.out.id, output); + handles.register_int_tensor::(&self.desc.out.id, output); } } }; diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index 4c6ec62fc7..1681d6a811 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -1,6 +1,6 @@ use crate::{ stream::{execution::Operation, MultiStream, StreamId}, - FusionBackend, + FusionBackend, FusionRuntime, }; use burn_tensor::{ ops::{FloatElem, IntElem}, @@ -8,23 +8,20 @@ use burn_tensor::{ }; use std::sync::Arc; -pub struct FusionServer -where - B: FusionBackend, -{ - streams: MultiStream, - pub(crate) handles: HandleContainer, - pub device: B::Device, +pub struct FusionServer { + streams: MultiStream, + pub(crate) handles: HandleContainer, + pub device: R::FusionDevice, } -impl FusionServer +impl FusionServer where - B: FusionBackend, + R: FusionRuntime, { - pub fn new(device: B::Device) -> Self { + pub fn new(device: R::FusionDevice) -> Self { Self { streams: MultiStream::new(device.clone()), - handles: HandleContainer::new(device.clone()), + handles: HandleContainer::new(), device, } } @@ -33,7 +30,7 @@ where &mut self, streams: Vec, desc: OperationDescription, - operation: Box>, + operation: Box>, ) { self.streams .register(streams, desc, operation, &mut self.handles) @@ -47,90 +44,110 @@ where self.handles.create_tensor_uninit() } - pub fn read_float( + pub fn read_float( &mut self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader, D>> { + ) -> burn_tensor::Reader, D>> + where + B: FusionBackend, + { // Make sure all registered operations are executed. // The underlying backend can still be async. self.drain_stream(id); - let tensor = self.handles.get_float_tensor(&tensor); + let tensor = self.handles.get_float_tensor::(&tensor); B::float_into_data(tensor) } - pub fn read_int( + pub fn read_int( &mut self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader, D>> { + ) -> burn_tensor::Reader, D>> + where + B: FusionBackend, + { // Make sure all registered operations are executed. // The underlying backend can still be async. 
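The server refactor running through these hunks has one recurring shape: the struct is generic over a fusion runtime, while each data-reading method introduces its own backend parameter tied back to that runtime. A minimal stand-alone sketch of that pattern follows; all names are simplified stand-ins, not the actual burn-fusion definitions.

```rust
trait FusionRuntime {
    type FusionDevice: Clone;
    type FusionHandle;
}

trait FusionBackend {
    type FusionRuntime: FusionRuntime;
    type FloatData;
}

// The server only knows the runtime: its device and its handle type.
struct FusionServer<R: FusionRuntime> {
    device: R::FusionDevice,
    handles: Vec<R::FusionHandle>,
}

impl<R: FusionRuntime> FusionServer<R> {
    // Reading data needs a concrete backend, so the bound appears per method:
    // any `B` whose runtime is exactly `R` is accepted.
    fn read_float<B>(&mut self) -> Option<B::FloatData>
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        // Placeholder: the real server drains the stream, then materializes data.
        None
    }
}
```

The design payoff is that several backends sharing one runtime (for example, the same JIT runtime with different float precisions) can share a single server and handle container.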
self.drain_stream(id); - let tensor = self.handles.get_int_tensor(&tensor); + let tensor = self.handles.get_int_tensor::(&tensor); B::int_into_data(tensor) } - pub fn read_bool( + pub fn read_bool( &mut self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader> { + ) -> burn_tensor::Reader> + where + B: FusionBackend, + { // Make sure all registered operations are executed. // The underlying backend can still be async. self.drain_stream(id); - let tensor = self.handles.get_bool_tensor(&tensor); + let tensor = self.handles.get_bool_tensor::(&tensor); B::bool_into_data(tensor) } - pub fn change_server_float( + pub fn change_server_float( &mut self, tensor: &TensorDescription, - device: &B::Device, + device: &R::FusionDevice, server_device: &mut Self, - ) -> Arc { - let tensor = self.handles.get_float_tensor::(tensor); + ) -> Arc + where + B: FusionBackend, + { + let tensor = self.handles.get_float_tensor::(tensor); let tensor = B::float_to_device(tensor, device); let id = server_device.create_empty_handle(); server_device .handles - .register_float_tensor(&id, tensor.clone()); + .register_float_tensor::(&id, tensor.clone()); id } - pub fn change_server_int( + + pub fn change_server_int( &mut self, tensor: &TensorDescription, - device: &B::Device, + device: &R::FusionDevice, server_device: &mut Self, - ) -> Arc { - let tensor = self.handles.get_int_tensor::(tensor); + ) -> Arc + where + B: FusionBackend, + { + let tensor = self.handles.get_int_tensor::(tensor); let tensor = B::int_to_device(tensor, device); let id = server_device.create_empty_handle(); server_device .handles - .register_int_tensor(&id, tensor.clone()); + .register_int_tensor::(&id, tensor.clone()); id } - pub fn change_server_bool( + + pub fn change_server_bool( &mut self, tensor: &TensorDescription, - device: &B::Device, + device: &R::FusionDevice, server_device: &mut Self, - ) -> Arc { - let tensor = self.handles.get_bool_tensor::(tensor); + ) -> Arc + where + B: FusionBackend, + { + let tensor = self.handles.get_bool_tensor::(tensor); let tensor = B::bool_to_device(tensor, device); let id = server_device.create_empty_handle(); server_device .handles - .register_bool_tensor(&id, tensor.clone()); + .register_bool_tensor::(&id, tensor.clone()); id } diff --git a/crates/burn-fusion/src/stream/base.rs b/crates/burn-fusion/src/stream/base.rs index fb3d8f99b3..31ebfb6146 100644 --- a/crates/burn-fusion/src/stream/base.rs +++ b/crates/burn-fusion/src/stream/base.rs @@ -1,18 +1,16 @@ -use burn_tensor::repr::OperationDescription; - -use crate::FusionBackend; - use super::{execution::Operation, OperationConverter, RelativeOps}; +use crate::FusionRuntime; +use burn_tensor::repr::OperationDescription; /// A growing list of [tensor operation descriptions](OperationDescription). -pub struct OperationQueue { +pub struct OperationQueue { pub(crate) global: Vec, pub(crate) relative: Vec, pub(crate) converter: OperationConverter, - pub(crate) operations: Vec>>, + pub(crate) operations: Vec>>, } -impl Default for OperationQueue { +impl Default for OperationQueue { fn default() -> Self { Self::new() } @@ -56,7 +54,7 @@ impl core::fmt::Display for StreamId { } } -impl OperationQueue { +impl OperationQueue { /// Create a new empty queue. 
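The queue change above follows the same idea: operations are stored type-erased against the runtime only, so the queue never names a concrete backend. A compilable toy version of that shape (all types are simplified stand-ins):

```rust
trait FusionRuntime {
    type FusionHandle;
}

struct HandleContainer<H> {
    _handles: Vec<H>,
}

// Execution needs only the handle container, never a concrete backend.
trait Operation<R: FusionRuntime>: Send + Sync {
    fn execute(self: Box<Self>, handles: &mut HandleContainer<R::FusionHandle>);
}

// The queue stores type-erased operations for one runtime.
struct OperationQueue<R: FusionRuntime> {
    operations: Vec<Box<dyn Operation<R>>>,
}

impl<R: FusionRuntime> OperationQueue<R> {
    fn new() -> Self {
        Self { operations: Vec::new() }
    }

    fn add(&mut self, operation: Box<dyn Operation<R>>) {
        self.operations.push(operation);
    }
}
```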
    pub fn new() -> Self {
        Self {
@@ -72,7 +70,7 @@ impl<B: FusionBackend> OperationQueue<B> {
    /// The new [operation description](OperationDescription) will be converted to a local
    /// representation that can be reused when the same pattern emerges in a different but
    /// similar scenario, so that the same optimization can be used.
-    pub fn add(&mut self, global: OperationDescription, operation: Box<dyn Operation<B>>) {
+    pub fn add(&mut self, global: OperationDescription, operation: Box<dyn Operation<R>>) {
        let relative = global.to_relative(&mut self.converter);
        self.relative.push(relative);
        self.global.push(global);
diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs
index 11441a40a8..a7d2d0454f 100644
--- a/crates/burn-fusion/src/stream/context.rs
+++ b/crates/burn-fusion/src/stream/context.rs
@@ -1,4 +1,3 @@
-use crate::FusionBackend;
 use burn_tensor::{repr::*, Element, ElementConversion};
 use hashbrown::HashMap;
@@ -9,11 +8,11 @@ use hashbrown::HashMap;
 /// It also contains all scalar values, which can change even for the same graph. They are sorted
 /// in the order in which they appear in the graph.
 #[derive(new)]
-pub struct Context<'a, B: FusionBackend> {
+pub struct Context<'a, H> {
     /// The tensor mapping where local tensor id points to the updated tensor description.
     pub tensors: &'a HashMap<TensorId, TensorDescription>,
     /// Handle container to retrieve tensors based on their description.
-    pub handles: &'a mut HandleContainer<B>,
+    pub handles: &'a mut HandleContainer<H>,
     /// Float scalars found in the graph in the order they appeared.
     pub scalar_floats: &'a Vec<f32>,
     /// Int scalars found in the graph in the order they appeared.
@@ -42,10 +41,7 @@ trait RelativeOpsScalar {
 }

 impl OperationConverter {
-    pub(crate) fn context<'a, B: FusionBackend>(
-        &'a self,
-        handles: &'a mut HandleContainer<B>,
-    ) -> Context<'a, B> {
+    pub(crate) fn context<'a, H>(&'a self, handles: &'a mut HandleContainer<H>) -> Context<'a, H> {
         Context {
             handles,
             tensors: &self.tensors_relative2global,
@@ -853,6 +849,12 @@ impl RelativeOps for BaseOperationDescription {
                     out: desc.out.to_relative(converter),
                 })
             }
+            BaseOperationDescription::Cast(desc) => {
+                BaseOperationDescription::Cast(UnaryOperationDescription {
+                    input: desc.input.to_relative(converter),
+                    out: desc.out.to_relative(converter),
+                })
+            }
         }
     }
 }
@@ -887,6 +889,7 @@ impl RelativeOps for TensorDescription {
             id: relative_id,
             shape: relative_shape,
             status: self.status.clone(),
+            dtype: self.dtype,
         };

         // We update both mappings.
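The `dtype: self.dtype` line above is the heart of the multi-dtype support: when a description is renumbered into its relative form, the element type must survive the mapping so that two graphs differing only in dtype are not treated as the same pattern. A toy version of that mapping, with stand-in types in place of the real ones:

```rust
#[derive(Clone, Copy, PartialEq, Debug)]
enum DType { F32, F16 }

#[derive(Clone, Debug)]
struct TensorDescription {
    id: u64,
    shape: Vec<usize>,
    dtype: DType,
}

// Renumbering rewrites id and shape; the dtype is carried over unchanged, so
// an f32 graph and an otherwise identical f16 graph stay distinct patterns.
fn to_relative(desc: &TensorDescription, relative_id: u64, relative_shape: Vec<usize>) -> TensorDescription {
    TensorDescription {
        id: relative_id,
        shape: relative_shape,
        dtype: desc.dtype,
    }
}

fn main() {
    let global = TensorDescription { id: 500, shape: vec![512, 32], dtype: DType::F16 };
    let local = to_relative(&global, 0, vec![0, 1]);
    assert_eq!(local.dtype, global.dtype);
}
```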
@@ -904,7 +907,10 @@ impl RelativeOps for TensorDescription { #[cfg(test)] mod tests { use super::*; - use burn_tensor::repr::{TensorDescription, TensorId, TensorStatus}; + use burn_tensor::{ + repr::{TensorDescription, TensorId, TensorStatus}, + DType, + }; #[test] fn tensor_description_to_relative() { @@ -912,11 +918,13 @@ mod tests { id: TensorId::new(500), shape: vec![512, 32, 2048], status: TensorStatus::ReadOnly, + dtype: DType::F32, }; let tensor2 = TensorDescription { id: TensorId::new(501), shape: vec![512, 128, 2048], status: TensorStatus::ReadOnly, + dtype: DType::F32, }; let mut converter = OperationConverter::default(); let tensor1_local = tensor1.to_relative(&mut converter); @@ -927,7 +935,8 @@ mod tests { TensorDescription { id: TensorId::new(0), shape: vec![0, 1, 2], - status: TensorStatus::ReadOnly + status: TensorStatus::ReadOnly, + dtype: DType::F32 } ); assert_eq!( @@ -935,7 +944,8 @@ mod tests { TensorDescription { id: TensorId::new(1), shape: vec![0, 3, 2], - status: TensorStatus::ReadOnly + status: TensorStatus::ReadOnly, + dtype: DType::F32 } ); } diff --git a/crates/burn-fusion/src/stream/execution/base.rs b/crates/burn-fusion/src/stream/execution/base.rs index 7b2dce0f50..d24733657b 100644 --- a/crates/burn-fusion/src/stream/execution/base.rs +++ b/crates/burn-fusion/src/stream/execution/base.rs @@ -5,7 +5,7 @@ use crate::{ store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy}, OperationQueue, RelativeOps, }, - FusionBackend, Optimization, + FusionRuntime, Optimization, }; /// The mode in which the execution is done. @@ -16,18 +16,18 @@ pub(crate) enum ExecutionMode { } /// General trait to abstract how a single operation is executed. -pub trait Operation: Send + Sync { +pub trait Operation: Send + Sync { /// Execute the operation. - fn execute(self: Box, handles: &mut HandleContainer); + fn execute(self: Box, handles: &mut HandleContainer); } -impl OperationQueue { +impl OperationQueue { /// Execute the queue partially following the execution strategy from the plan. 
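The `Operation` trait now executes against a `HandleContainer<R::FusionHandle>` rather than a backend. This is why the macro-generated op structs seen earlier grow a `PhantomData<B>` field: the backend is needed for its functions and its associated runtime, but no value of it is ever stored. A self-contained sketch of the pattern, using stand-in traits rather than the crate's real ones:

```rust
use std::marker::PhantomData;

trait FusionRuntime {
    type FusionHandle;
}

trait FusionBackend {
    type FusionRuntime: FusionRuntime;
    fn exp(x: f32) -> f32;
}

struct HandleContainer<H> {
    _handles: Vec<H>,
}

trait Operation<R: FusionRuntime> {
    fn execute(self: Box<Self>, handles: &mut HandleContainer<R::FusionHandle>);
}

// The op names the backend `B` purely at the type level.
struct ExpOps<B: FusionBackend> {
    input: f32,
    _b: PhantomData<B>, // zero-sized marker: ties the struct to `B` without storing one
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpOps<B> {
    fn execute(
        self: Box<Self>,
        _handles: &mut HandleContainer<<B::FusionRuntime as FusionRuntime>::FusionHandle>,
    ) {
        // A real op would fetch its inputs from the container and register outputs.
        let _y = B::exp(self.input);
    }
}
```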
pub(crate) fn execute( &mut self, id: ExecutionPlanId, - handles: &mut HandleContainer, - store: &mut ExecutionPlanStore, + handles: &mut HandleContainer, + store: &mut ExecutionPlanStore, ) { match &mut store.get_mut_unchecked(id).strategy { ExecutionStrategy::Optimization(optimization) => { @@ -39,8 +39,8 @@ impl OperationQueue { fn execute_optimization( &mut self, - handles: &mut HandleContainer, - optimization: &mut B::Optimization, + handles: &mut HandleContainer, + optimization: &mut R::Optimization, ) { let num_drained = optimization.len(); @@ -51,7 +51,7 @@ impl OperationQueue { self.operations.drain(0..num_drained); } - fn execute_operations(&mut self, handles: &mut HandleContainer) { + fn execute_operations(&mut self, handles: &mut HandleContainer) { let num_drained = self.operations.len(); for operation in self.operations.drain(0..num_drained) { @@ -61,7 +61,7 @@ impl OperationQueue { self.drain_stream(num_drained, handles); } - fn drain_stream(&mut self, num_drained: usize, handles: &mut HandleContainer) { + fn drain_stream(&mut self, num_drained: usize, handles: &mut HandleContainer) { self.global[0..num_drained] .iter() .flat_map(|desc| desc.nodes()) diff --git a/crates/burn-fusion/src/stream/execution/policy.rs b/crates/burn-fusion/src/stream/execution/policy.rs index e56f7bce62..5424ec5364 100644 --- a/crates/burn-fusion/src/stream/execution/policy.rs +++ b/crates/burn-fusion/src/stream/execution/policy.rs @@ -265,9 +265,12 @@ impl Policy { #[cfg(test)] mod tests { - use burn_tensor::repr::{ - FloatOperationDescription, TensorDescription, TensorId, TensorStatus, - UnaryOperationDescription, + use burn_tensor::{ + repr::{ + FloatOperationDescription, TensorDescription, TensorId, TensorStatus, + UnaryOperationDescription, + }, + DType, }; use super::*; @@ -557,6 +560,7 @@ mod tests { id: TensorId::new(id), shape: vec![32, 32, 1], status: TensorStatus::NotInit, + dtype: DType::F32, }); } diff --git a/crates/burn-fusion/src/stream/execution/tests.rs b/crates/burn-fusion/src/stream/execution/tests.rs index 6755b624b2..0e13546e17 100644 --- a/crates/burn-fusion/src/stream/execution/tests.rs +++ b/crates/burn-fusion/src/stream/execution/tests.rs @@ -6,10 +6,13 @@ //! To test these components effectively, we create mock types for the stream, optimization, //! optimization builder, and stream segment. These mock types aid in comprehensively //! understanding the process of optimizing streams. 
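Every hand-written fixture in these tests now repeats `dtype: DType::F32`. A small helper in this spirit (hypothetical, not part of the diff) would keep fixtures terse; the types below are reduced stand-ins for the fields the tests use:

```rust
#[derive(Clone, Copy)]
enum DType { F32 }

enum TensorStatus { ReadOnly, NotInit }

struct TensorDescription {
    id: u64,
    shape: Vec<usize>,
    status: TensorStatus,
    dtype: DType,
}

// Hypothetical fixture helper: defaults every test description to f32.
fn desc(id: u64, shape: Vec<usize>, status: TensorStatus) -> TensorDescription {
    TensorDescription { id, shape, status, dtype: DType::F32 }
}

fn main() {
    let lhs = desc(0, vec![32, 32], TensorStatus::ReadOnly);
    let out = desc(2, vec![32, 32], TensorStatus::NotInit);
    let _ = (lhs, out);
}
```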
-use burn_tensor::repr::{ - BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription, - OperationDescription, ScalarOperationDescription, TensorDescription, TensorId, TensorStatus, - UnaryOperationDescription, +use burn_tensor::{ + repr::{ + BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription, + OperationDescription, ScalarOperationDescription, TensorDescription, TensorId, + TensorStatus, UnaryOperationDescription, + }, + DType, }; use crate::{ @@ -523,16 +526,19 @@ fn operation_1() -> OperationDescription { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: TensorDescription { id: TensorId::new(1), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) @@ -546,12 +552,14 @@ fn operation_2() -> OperationDescription { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: 5.0, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) @@ -564,11 +572,13 @@ fn operation_3() -> OperationDescription { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, out: TensorDescription { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, })) } diff --git a/crates/burn-fusion/src/stream/multi.rs b/crates/burn-fusion/src/stream/multi.rs index 035b6ed280..4edcbfe1ce 100644 --- a/crates/burn-fusion/src/stream/multi.rs +++ b/crates/burn-fusion/src/stream/multi.rs @@ -5,18 +5,18 @@ use super::{ store::{ExecutionPlanId, ExecutionPlanStore}, OperationQueue, StreamId, }; -use crate::FusionBackend; +use crate::FusionRuntime; use std::collections::HashMap; /// Keep track of multiple concurrent streams of operations. -pub struct MultiStream { - streams: HashMap>, - optimizations: ExecutionPlanStore, - device: B::Device, +pub struct MultiStream { + streams: HashMap>, + optimizations: ExecutionPlanStore, + device: R::FusionDevice, } -impl MultiStream { - pub(crate) fn new(device: B::Device) -> Self { +impl MultiStream { + pub(crate) fn new(device: R::FusionDevice) -> Self { Self { streams: HashMap::new(), optimizations: ExecutionPlanStore::new(), @@ -29,8 +29,8 @@ impl MultiStream { &mut self, streams: Vec, desc: OperationDescription, - operation: Box>, - handles: &mut HandleContainer, + operation: Box>, + handles: &mut HandleContainer, ) { let id = self.maybe_drain(streams, handles); @@ -65,7 +65,7 @@ impl MultiStream { } /// Drain the streams. 
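For orientation, the multi-stream bookkeeping above boils down to a map of queues keyed by stream id, where draining removes the queue and flushes it against the shared handle container. A condensed, compilable stand-in (not the real types):

```rust
use std::collections::HashMap;

struct HandleContainer;

struct OperationQueue {
    pending: Vec<&'static str>,
}

struct MultiStream {
    streams: HashMap<u64, OperationQueue>,
}

impl MultiStream {
    /// Draining removes the stream and flushes its queue against the shared container.
    fn drain(&mut self, handles: &mut HandleContainer, id: u64) {
        if let Some(stream) = self.streams.remove(&id) {
            for op in stream.pending {
                // Each queued operation would execute against `handles` here.
                let _ = (op, &mut *handles);
            }
        }
    }
}
```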
- pub fn drain(&mut self, handles: &mut HandleContainer, id: StreamId) { + pub fn drain(&mut self, handles: &mut HandleContainer, id: StreamId) { if let Some(mut stream) = self.streams.remove(&id) { stream.processor.process( Segment::new(&mut stream.queue, handles), @@ -80,7 +80,7 @@ impl MultiStream { fn maybe_drain( &mut self, streams: Vec, - handles: &mut HandleContainer, + handles: &mut HandleContainer, ) -> StreamId { let streams = Self::remove_duplicate(streams); let current = StreamId::current(); @@ -113,7 +113,7 @@ impl MultiStream { output } - fn free_orphans(&self, handles: &mut HandleContainer) { + fn free_orphans(&self, handles: &mut HandleContainer) { let nodes = self .streams .values() @@ -126,31 +126,31 @@ impl MultiStream { } } -struct Stream { - queue: OperationQueue, - processor: Processor, +struct Stream { + queue: OperationQueue, + processor: Processor, } #[derive(new)] -struct Segment<'a, B: FusionBackend> { - queue: &'a mut OperationQueue, - handles: &'a mut HandleContainer, +struct Segment<'a, R: FusionRuntime> { + queue: &'a mut OperationQueue, + handles: &'a mut HandleContainer, } -impl<'i, B: FusionBackend> StreamSegment for Segment<'i, B> { +impl<'i, R: FusionRuntime> StreamSegment for Segment<'i, R> { fn operations(&self) -> &[OperationDescription] { &self.queue.relative } - fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore) { + fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore) { self.queue.execute(id, self.handles, store) } } -impl Stream { - fn new(device: B::Device) -> Self { +impl Stream { + fn new(device: R::FusionDevice) -> Self { Self { - processor: Processor::new(B::optimizations(device)), + processor: Processor::new(R::optimizations(device)), queue: OperationQueue::new(), } } diff --git a/crates/burn-fusion/src/stream/store/index.rs b/crates/burn-fusion/src/stream/store/index.rs index b1c3111ac9..15bb56341a 100644 --- a/crates/burn-fusion/src/stream/store/index.rs +++ b/crates/burn-fusion/src/stream/store/index.rs @@ -116,9 +116,12 @@ impl ExecutionPlanIndex { #[cfg(test)] mod tests { - use burn_tensor::repr::{ - BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription, - TensorDescription, TensorId, TensorStatus, + use burn_tensor::{ + repr::{ + BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription, + TensorDescription, TensorId, TensorStatus, + }, + DType, }; use super::*; @@ -221,16 +224,19 @@ mod tests { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: TensorDescription { id: TensorId::new(1), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) @@ -243,12 +249,14 @@ mod tests { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: 5.0, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) @@ -261,16 +269,19 @@ mod tests { id: TensorId::new(0), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, rhs: TensorDescription { id: TensorId::new(1), shape: vec![32, 32], status: TensorStatus::ReadOnly, + dtype: DType::F32, }, out: TensorDescription { id: TensorId::new(2), shape: vec![32, 32], status: TensorStatus::NotInit, + dtype: DType::F32, }, }, )) diff --git a/crates/burn-fusion/src/tensor.rs 
b/crates/burn-fusion/src/tensor.rs index 6fad70e723..54d3f12f1c 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -1,21 +1,21 @@ -use crate::{client::FusionClient, stream::StreamId}; +use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime}; use burn_tensor::{ - backend::Backend, ops::{FloatElem, IntElem}, repr::{TensorDescription, TensorId, TensorStatus}, - Data, Reader, Shape, + DType, Data, Reader, Shape, }; use std::sync::Arc; /// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind. -#[derive(Clone)] -pub struct FusionTensor { +pub struct FusionTensor { /// Tensor id. pub id: Arc, /// The shape of the tensor. pub shape: Vec, /// The [fusion client](FusionClient). - pub client: C, + pub client: Client, + /// The datatype of the tensor. + pub dtype: DType, // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`. // // When a tensor is dropped and is still an orphan, we need to register it as such to avoid @@ -24,15 +24,27 @@ pub struct FusionTensor { pub(crate) stream: StreamId, } -impl core::fmt::Debug for FusionTensor { +impl Clone for FusionTensor { + fn clone(&self) -> Self { + Self { + id: self.id.clone(), + shape: self.shape.clone(), + client: self.client.clone(), + dtype: self.dtype, + is_orphan: self.is_orphan, + stream: self.stream, + } + } +} + +impl core::fmt::Debug for FusionTensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str( format!( - "{{ id: {:?}, shape: {:?}, should_drop: {:?}, backend: {:?}, device: {:?} }}", + "{{ id: {:?}, shape: {:?}, should_drop: {:?}, device: {:?} }}", self.id, self.shape, self.is_orphan, - ::name(), self.client.device().clone(), ) .as_str(), @@ -40,12 +52,19 @@ impl core::fmt::Debug for FusionTensor { } } -impl FusionTensor { - pub(crate) fn new(id: Arc, shape: Vec, client: C, stream: StreamId) -> Self { +impl FusionTensor { + pub(crate) fn new( + id: Arc, + shape: Vec, + dtype: DType, + client: Client, + stream: StreamId, + ) -> Self { Self { id, shape, client, + dtype, is_orphan: true, stream, } @@ -68,6 +87,7 @@ impl FusionTensor { status: TensorStatus::NotInit, shape: self.shape.clone(), id: *self.id.as_ref(), + dtype: self.dtype, } } @@ -85,34 +105,42 @@ impl FusionTensor { status, shape: shape_out, id: *self.id.as_ref(), + dtype: self.dtype, } } - pub(crate) fn into_data(self) -> Reader, D>> { + pub(crate) fn into_data(self) -> Reader, D>> + where + B: FusionBackend, + { let id = self.stream; self.client .clone() - .read_tensor_float(self.into_description(), id) + .read_tensor_float::(self.into_description(), id) } - pub(crate) fn int_into_data( - self, - ) -> Reader, D>> { + pub(crate) fn int_into_data(self) -> Reader, D>> + where + B: FusionBackend, + { let id = self.stream; self.client .clone() - .read_tensor_int(self.into_description(), id) + .read_tensor_int::(self.into_description(), id) } - pub(crate) fn bool_into_data(self) -> Reader> { + pub(crate) fn bool_into_data(self) -> Reader> + where + B: FusionBackend, + { let id = self.stream; self.client .clone() - .read_tensor_bool(self.into_description(), id) + .read_tensor_bool::(self.into_description(), id) } } -impl Drop for FusionTensor { +impl Drop for FusionTensor { fn drop(&mut self) { if !self.is_orphan { return; diff --git a/crates/burn-import/src/burn/codegen.rs b/crates/burn-import/src/burn/codegen.rs index ed617e8086..734aab2d11 100644 --- a/crates/burn-import/src/burn/codegen.rs +++ 
b/crates/burn-import/src/burn/codegen.rs @@ -9,10 +9,10 @@ fn convert_primitive(primitive: T) -> TokenStream { value.parse().unwrap() } -fn convert_to_array<'a, I, T: ToTokens>(list: I) -> TokenStream +fn convert_to_array<'a, I, T>(list: I) -> TokenStream where I: Iterator, - T: 'a, + T: ToTokens + 'a, { let mut body = quote! {}; diff --git a/crates/burn-import/src/onnx/coalesce.rs b/crates/burn-import/src/onnx/coalesce.rs index ddf102ca5e..c3d5d93d21 100644 --- a/crates/burn-import/src/onnx/coalesce.rs +++ b/crates/burn-import/src/onnx/coalesce.rs @@ -173,5 +173,7 @@ fn convert_and_remove_add_node(bias_node: &Node, current_node: &mut Node) { // Push the bias input and update the output name current_node.inputs.push(bias_input); - current_node.outputs[0].name = bias_node.outputs[0].name.clone(); + current_node.outputs[0] + .name + .clone_from(&bias_node.outputs[0].name); } diff --git a/crates/burn-import/src/onnx/from_onnx.rs b/crates/burn-import/src/onnx/from_onnx.rs index d9579294a5..0a38722aec 100644 --- a/crates/burn-import/src/onnx/from_onnx.rs +++ b/crates/burn-import/src/onnx/from_onnx.rs @@ -322,7 +322,7 @@ impl OnnxGraphBuilder { node.node_type, self.node_name_counter[&node.node_type] ) .to_lowercase(); - node.name = new_name.clone(); + node.name.clone_from(&new_name); } fn check_constants(&mut self, node: &mut Node, i: usize, _graph_io: &mut OnnxGraphIO) { @@ -343,7 +343,7 @@ impl OnnxGraphBuilder { ); if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { // The value comes from Identity inputs - input.value = constant.inputs[0].value.clone(); + input.value.clone_from(&constant.inputs[0].value); input.ty = constant.inputs[0].ty.clone(); } else { let arg = convert_constant_value(constant); @@ -383,7 +383,7 @@ impl OnnxGraphBuilder { if let Some(identity_idx) = self.identity_idx.get(&x.name) { let input_name = &self.nodes[*identity_idx].inputs[0].name; - x.name = input_name.clone(); + x.name.clone_from(input_name); } }); } @@ -454,7 +454,7 @@ fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) { match node.outputs[0].ty { ArgType::Tensor(ref mut tensor_type) => { if let ArgType::Tensor(arg_tensor) = &out_arg.ty { - tensor_type.shape = arg_tensor.shape.clone(); + tensor_type.shape.clone_from(&arg_tensor.shape); let inner = arg_tensor .shape .clone() @@ -497,7 +497,7 @@ fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { for node_input in node.inputs.iter_mut() { if let Some(input_name) = graph_io.get_new_name(&node_input.name) { node_input.passed = true; - node_input.name = input_name.clone(); + node_input.name.clone_from(&input_name); } else { node_input.name = "".to_string(); node_input.passed = false; @@ -507,7 +507,7 @@ fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { if node.node_type == NodeType::Constant || node.node_type == NodeType::Identity { let new_name = format!("{}_out{}", node.name, out_count); graph_io.insert(&node.outputs[0], &new_name); - node.outputs[0].name = new_name.clone(); + node.outputs[0].name.clone_from(&new_name); log::debug!("Found {} constant", new_name); } else { for output in node.outputs.iter_mut() { @@ -517,7 +517,7 @@ fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { graph_io.update_name(output, &new_name); - output.name = new_name.clone(); + output.name.clone_from(&new_name); out_count += 1; } } diff --git a/crates/burn-import/src/onnx/ir.rs b/crates/burn-import/src/onnx/ir.rs index eb229ea7ce..620ffed4ec 100644 --- a/crates/burn-import/src/onnx/ir.rs +++ 
b/crates/burn-import/src/onnx/ir.rs
@@ -29,7 +29,7 @@ impl Argument {
     /// Copy everything except the name from the other argument
     pub fn copy_value(&mut self, other_arg: &Argument) {
         self.ty = other_arg.ty.clone();
-        self.value = other_arg.value.clone();
+        self.value.clone_from(&other_arg.value);
     }

     pub fn from_initializer(initializer: &TensorProto) -> Argument {
diff --git a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs
index fbd0deb19f..b007533e72 100644
--- a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs
+++ b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs
@@ -1,5 +1,6 @@
 use super::Scope;
 use crate::kernel::WORKGROUP_DEFAULT;
+use burn_tensor::DType;
 use serde::{Deserialize, Serialize};
 use std::fmt::Display;
@@ -48,6 +49,25 @@ impl From<Elem> for Item {
     }
 }

+impl From<DType> for Elem {
+    fn from(dtype: DType) -> Self {
+        match dtype {
+            DType::F64 => Elem::Float(FloatKind::F64),
+            DType::F32 => Elem::Float(FloatKind::F32),
+            DType::F16 => Elem::Float(FloatKind::F16),
+            DType::BF16 => Elem::Float(FloatKind::BF16),
+            DType::I64 => Elem::Int(IntKind::I64),
+            DType::I32 => Elem::Int(IntKind::I32),
+            DType::I16 => panic!("i16 isn't supported yet."),
+            DType::I8 => panic!("i8 isn't supported yet."),
+            DType::U64 => Elem::UInt,
+            DType::U32 => Elem::UInt,
+            DType::U8 => panic!("u8 isn't supported yet."),
+            DType::Bool => Elem::Bool,
+        }
+    }
+}
+
 impl Display for Elem {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
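Since `From<DType> for Elem` is a plain conversion table, call sites can simply write `.into()` on a tensor's dtype. A stand-alone version with trimmed-down enums shows the same shape and is runnable as-is:

```rust
#[derive(Debug, Clone, Copy)]
enum DType { F32, F16, I32, U32, Bool, I8 }

#[derive(Debug, Clone, Copy)]
enum FloatKind { F32, F16 }

#[derive(Debug, Clone, Copy)]
enum IntKind { I32 }

#[derive(Debug, Clone, Copy)]
enum Elem { Float(FloatKind), Int(IntKind), UInt, Bool }

impl From<DType> for Elem {
    fn from(dtype: DType) -> Self {
        match dtype {
            DType::F32 => Elem::Float(FloatKind::F32),
            DType::F16 => Elem::Float(FloatKind::F16),
            DType::I32 => Elem::Int(IntKind::I32),
            DType::U32 => Elem::UInt,
            DType::Bool => Elem::Bool,
            // Mirrors the diff: unsupported kinds fail loudly at kernel-build time.
            DType::I8 => panic!("i8 isn't supported yet."),
        }
    }
}

fn main() {
    let elem: Elem = DType::F16.into();
    println!("{elem:?}"); // Float(F16)
}
```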
diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs
index c5aa160aa8..693e3db2d5 100644
--- a/crates/burn-jit/src/element.rs
+++ b/crates/burn-jit/src/element.rs
@@ -4,8 +4,6 @@ use burn_tensor::Element;
 /// The base element trait for the jit backend.
 pub trait JitElement:
     burn_tensor::Element + core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod
-where
-    Self: Sized,
 {
     /// TODO: Remove when all wgsl static kernels are migrated.
     fn type_name() -> &'static str;
@@ -92,6 +90,27 @@ impl JitElement for f32 {
     }
 }

+impl JitElement for half::f16 {
+    fn type_name() -> &'static str {
+        "f16"
+    }
+    fn as_bytes(slice: &[Self]) -> &[u8] {
+        bytemuck::cast_slice(slice)
+    }
+    fn from_bytes(bytes: &[u8]) -> &[Self] {
+        bytemuck::cast_slice(bytes)
+    }
+    fn gpu_elem() -> gpu::Elem {
+        gpu::Elem::Float(gpu::FloatKind::F16)
+    }
+    fn maximum_value() -> Self {
+        half::f16::MAX
+    }
+    fn minimum_value() -> Self {
+        half::f16::MIN
+    }
+}
+
 impl JitElement for half::bf16 {
     fn type_name() -> &'static str {
         "bf16"
@@ -114,4 +133,5 @@ impl JitElement for half::bf16 {
 }
 impl FloatElement for f32 {}
 impl FloatElement for half::bf16 {}
+impl FloatElement for half::f16 {}
 impl IntElement for i32 {}
diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs
index f1a49290e6..919670524e 100644
--- a/crates/burn-jit/src/fusion/base.rs
+++ b/crates/burn-jit/src/fusion/base.rs
@@ -1,12 +1,13 @@
 use super::{ElementWise, ElementWiseState};
 use crate::{
-    element::JitElement, fusion::ElementWiseBuilder, tensor::JitTensor, FloatElement, IntElement,
-    JitBackend, Runtime,
+    element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement,
+    IntElement, JitBackend, Runtime,
 };
 use burn_compute::client::ComputeClient;
-use burn_fusion::{client::MutexFusionClient, FusionBackend};
+use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
 use burn_tensor::{repr::ReprBackend, Shape};
 use core::marker::PhantomData;
+use half::{bf16, f16};
 use serde::{Deserialize, Serialize};

 /// Fusion optimization type for JIT.
@@ -26,13 +27,11 @@ pub enum JitOptimizationState {
     ElementWise(ElementWiseState),
 }

-impl<R, F, I> burn_fusion::Optimization<JitBackend<R, F, I>> for JitOptimization<R>
+impl<R> burn_fusion::Optimization<FusionJitRuntime<R>> for JitOptimization<R>
 where
     R: Runtime,
-    F: FloatElement,
-    I: IntElement,
 {
-    fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitBackend<R, F, I>>) {
+    fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle<R>>) {
         match self {
             Self::ElementWise(op) => op.execute(context),
         }
@@ -102,15 +101,46 @@ impl<R: Runtime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R, F, I>
-impl<R: Runtime, F: FloatElement, I: IntElement> FusionBackend for JitBackend<R, F, I> {
+impl<R: Runtime> FusionRuntime for FusionJitRuntime<R> {
     type OptimizationState = JitOptimizationState;
     type Optimization = JitOptimization<R>;
+    type FusionHandle = JitFusionHandle<R>;
+    type FusionDevice = R::Device;
     type FusionClient = MutexFusionClient<Self>;

     fn optimizations(
         device: R::Device,
     ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
-        vec![Box::new(ElementWiseBuilder::<R, F, I>::new(device))]
+        vec![Box::new(ElementWiseBuilder::<R>::new(device))]
+    }
+}
+
+#[derive(Debug)]
+pub struct FusionJitRuntime<R: Runtime> {
+    _b: PhantomData<R>,
+}
+
+impl<R: Runtime, F: FloatElement, I: IntElement> FusionBackend for JitBackend<R, F, I> {
+    type FusionRuntime = FusionJitRuntime<R>;
+
+    type FullPrecisionBackend = JitBackend<R, f32, i32>;
+
+    fn cast_float<const D: usize>(
+        tensor: burn_tensor::ops::FloatTensor<Self, D>,
+        dtype: burn_tensor::DType,
+    ) -> Self::Handle {
+        fn cast<R: Runtime, F: FloatElement, FTarget: FloatElement, const D: usize>(
+            tensor: JitTensor<R, F, D>,
+        ) -> JitFusionHandle<R> {
+            JitFusionHandle::from(kernel::cast::<R, F, FTarget, D>(tensor))
+        }
+
+        match dtype {
+            burn_tensor::DType::F32 => cast::<R, F, f32, D>(tensor),
+            burn_tensor::DType::F16 => cast::<R, F, f16, D>(tensor),
+            burn_tensor::DType::BF16 => cast::<R, F, bf16, D>(tensor),
+            _ => panic!("Casting error: {dtype:?} unsupported."),
+        }
+    }
 }
diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs
index 0cd346af70..bea4cbf5fb 100644
--- a/crates/burn-jit/src/fusion/elemwise/builder.rs
+++ b/crates/burn-jit/src/fusion/elemwise/builder.rs
@@ -1,42 +1,31 @@
-use core::marker::PhantomData;
-
-use
super::{optimization::ElementWise, CompilationPhase}; use crate::{ codegen::dialect::gpu::{ - BinaryOperator, ConditionalAssign, Elem, Operator, Procedure, UnaryOperator, Variable, + BinaryOperator, ConditionalAssign, Operator, Procedure, UnaryOperator, Variable, }, - element::JitElement, fusion::{tracing::TraceBuilder, JitOptimization}, - FloatElement, IntElement, JitBackend, Runtime, + Runtime, }; use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus}; use burn_tensor::{ - ops::{FloatElem, IntElem}, repr::{ BaseOperationDescription, BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription, OperationDescription, ScalarOperationDescription, TensorDescription, UnaryOperationDescription, }, - Device, Element, + Element, }; /// Fused element wise operations that are normally memory bound. -pub(crate) struct ElementWiseBuilder { +pub(crate) struct ElementWiseBuilder { builder: TraceBuilder, current_output_shape: Vec, status: OptimizationStatus, num_added: usize, device: R::Device, - _float_elem: PhantomData, - _int_elem: PhantomData, } -impl OptimizationBuilder> for ElementWiseBuilder -where - R: Runtime, - F: FloatElement, - I: IntElement, -{ +impl OptimizationBuilder> for ElementWiseBuilder { fn register(&mut self, ops: &OperationDescription) { if let OptimizationStatus::Closed = self.status { return; @@ -44,31 +33,31 @@ where match ops { OperationDescription::BaseFloat(ops) => { - if !self.register_base::>>(ops) { + if !self.register_base(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::BaseInt(ops) => { - if !self.register_base::>>(ops) { + if !self.register_base(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::Float(ops) => { - if !self.register_float::>>(ops) { + if !self.register_float(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::NumericFloat(ops) => { - if !self.register_numeric::>, _>(ops) { + if !self.register_numeric::(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::NumericInt(ops) => { - if !self.register_numeric::>, _>(ops) { + if !self.register_numeric::(ops) { self.status = OptimizationStatus::Closed; return; } @@ -119,185 +108,149 @@ where } } -impl ElementWiseBuilder { - pub fn new(device: Device>) -> Self { +impl ElementWiseBuilder { + pub fn new(device: R::Device) -> Self { Self { builder: TraceBuilder::new(), num_added: 0, current_output_shape: Vec::new(), status: OptimizationStatus::Open, device, - _float_elem: PhantomData, - _int_elem: PhantomData, } } - fn register_base(&mut self, ops: &BaseOperationDescription) -> bool { + fn register_base(&mut self, ops: &BaseOperationDescription) -> bool { match ops { - BaseOperationDescription::Equal(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Equal(BinaryOperator { lhs, rhs, out }), - ), + BaseOperationDescription::Equal(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Equal(BinaryOperator { lhs, rhs, out }) + }), + BaseOperationDescription::Cast(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Assign(UnaryOperator { input, out }) + }), _ => false, } } - fn register_float(&mut self, ops: &FloatOperationDescription) -> bool { + fn register_float(&mut self, ops: &FloatOperationDescription) -> bool { match ops { - FloatOperationDescription::Exp(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { 
- Operator::Exp(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Log(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Log(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Log1p(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { + FloatOperationDescription::Exp(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Exp(UnaryOperator { input, out }) + }), + FloatOperationDescription::Log(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Log(UnaryOperator { input, out }) + }), + FloatOperationDescription::Log1p(desc) => self + .register_unary_ops(desc, |input, out| { Operator::Log1p(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Cos(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Cos(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Sin(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Sin(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::PowfScalar(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Powf(BinaryOperator { lhs, rhs, out }), - ), - FloatOperationDescription::Tanh(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Tanh(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Erf(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { - Operator::Erf(UnaryOperator { input, out }) - }) - } - FloatOperationDescription::Recip(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { + }), + FloatOperationDescription::Cos(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Cos(UnaryOperator { input, out }) + }), + FloatOperationDescription::Sin(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Sin(UnaryOperator { input, out }) + }), + FloatOperationDescription::PowfScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Powf(BinaryOperator { lhs, rhs, out }) + }), + FloatOperationDescription::Tanh(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Tanh(UnaryOperator { input, out }) + }), + FloatOperationDescription::Erf(desc) => self.register_unary_ops(desc, |input, out| { + Operator::Erf(UnaryOperator { input, out }) + }), + FloatOperationDescription::Recip(desc) => self + .register_unary_ops(desc, |input, out| { Operator::Recip(UnaryOperator { input, out }) - }) - } + }), _ => false, } } - fn register_numeric( - &mut self, - ops: &NumericOperationDescription, - ) -> bool { + fn register_numeric(&mut self, ops: &NumericOperationDescription) -> bool { match ops { - NumericOperationDescription::Add(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Add(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::AddScalar(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Add(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::Sub(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Sub(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::SubScalar(desc) => 
self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Sub(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::Mul(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Mul(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::MulScalar(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Mul(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::Div(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Div(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::DivScalar(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), E::gpu_elem()), - |lhs, rhs, out| Operator::Div(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::Abs(desc) => { - self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| { + NumericOperationDescription::Add(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Add(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::AddScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Add(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Sub(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Sub(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::SubScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Sub(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Mul(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Mul(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::MulScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Mul(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Div(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Div(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::DivScalar(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Div(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Abs(desc) => self + .register_unary_ops(desc, |input, out| { Operator::Abs(UnaryOperator { input, out }) - }) - } - NumericOperationDescription::Lower(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Lower(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::LowerElem(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Lower(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::Greater(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Greater(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::GreaterElem(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Greater(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::LowerEqual(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::LowerEqual(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::LowerEqualElem(desc) => 
self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::LowerEqual(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::GreaterEqual(desc) => self.register_binary_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::GreaterEqualElem(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }), - ), - NumericOperationDescription::EqualElem(desc) => self.register_scalar_ops( - desc, - (E::gpu_elem(), E::gpu_elem(), Elem::Bool), - |lhs, rhs, out| Operator::Equal(BinaryOperator { lhs, rhs, out }), - ), + }), + NumericOperationDescription::Lower(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Lower(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::LowerElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Lower(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::Greater(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::Greater(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::GreaterElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Greater(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::LowerEqual(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::LowerEqual(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::LowerEqualElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::LowerEqual(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::GreaterEqual(desc) => self + .register_binary_ops(desc, |lhs, rhs, out| { + Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::GreaterEqualElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }) + }), + NumericOperationDescription::EqualElem(desc) => self + .register_scalar_ops(desc, |lhs, rhs, out| { + Operator::Equal(BinaryOperator { lhs, rhs, out }) + }), NumericOperationDescription::MaskWhere(desc) => { if !self.output_is_compatible(&desc.out) { return false; } - let cond = self.builder.input(&desc.mask, Elem::Bool); - let lhs = self.builder.input(&desc.value, E::gpu_elem()); - let rhs = self.builder.input(&desc.tensor, E::gpu_elem()); - let out = self.builder.output(&desc.out, E::gpu_elem()); + let cond = self.builder.input(&desc.mask); + let lhs = self.builder.input(&desc.value); + let rhs = self.builder.input(&desc.tensor); + let out = self.builder.output(&desc.out); self.builder .register_operation(Procedure::ConditionalAssign(ConditionalAssign { @@ -314,10 +267,10 @@ impl ElementWiseBuilder { return false; } - let cond = self.builder.input(&desc.mask, Elem::Bool); - let lhs = self.builder.scalar(&desc.value, E::gpu_elem()); - let rhs = self.builder.input(&desc.tensor, E::gpu_elem()); - let out = self.builder.output(&desc.out, E::gpu_elem()); + let cond = self.builder.input(&desc.mask); + let lhs = self.builder.scalar(&desc.value, desc.out.dtype.into()); + let rhs = self.builder.input(&desc.tensor); + let out = self.builder.output(&desc.out); self.builder .register_operation(Procedure::ConditionalAssign(ConditionalAssign { @@ -334,8 +287,8 @@ impl ElementWiseBuilder { return false; } - let input = 
Variable::ConstantScalar(1.0, E::gpu_elem()); - let out = self.builder.output(desc, E::gpu_elem()); + let input = Variable::ConstantScalar(1.0, desc.dtype.into()); + let out = self.builder.output(desc); self.builder .register_operation(Operator::Assign(UnaryOperator { input, out })); @@ -347,8 +300,8 @@ impl ElementWiseBuilder { return false; } - let input = Variable::ConstantScalar(0.0, E::gpu_elem()); - let out = self.builder.output(desc, E::gpu_elem()); + let input = Variable::ConstantScalar(0.0, desc.dtype.into()); + let out = self.builder.output(desc); self.builder .register_operation(Operator::Assign(UnaryOperator { input, out })); @@ -360,8 +313,8 @@ impl ElementWiseBuilder { return false; } - let input = self.builder.scalar(elem, E::gpu_elem()); - let out = self.builder.output(desc, E::gpu_elem()); + let input = self.builder.scalar(elem, desc.dtype.into()); + let out = self.builder.output(desc); self.builder .register_operation(Operator::Assign(UnaryOperator { input, out })); @@ -372,12 +325,7 @@ impl ElementWiseBuilder { } } - fn register_binary_ops( - &mut self, - desc: &BinaryOperationDescription, - (elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem), - func: Func, - ) -> bool + fn register_binary_ops(&mut self, desc: &BinaryOperationDescription, func: Func) -> bool where Func: Fn(Variable, Variable, Variable) -> Operator, { @@ -385,21 +333,16 @@ impl ElementWiseBuilder { return false; } - let lhs = self.builder.input(&desc.lhs, elem_lhs); - let rhs = self.builder.input(&desc.rhs, elem_rhs); - let out = self.builder.output(&desc.out, elem_out); + let lhs = self.builder.input(&desc.lhs); + let rhs = self.builder.input(&desc.rhs); + let out = self.builder.output(&desc.out); self.builder.register_operation(func(lhs, rhs, out)); true } - fn register_unary_ops( - &mut self, - desc: &UnaryOperationDescription, - (elem_input, elem_out): (Elem, Elem), - func: Func, - ) -> bool + fn register_unary_ops(&mut self, desc: &UnaryOperationDescription, func: Func) -> bool where Func: Fn(Variable, Variable) -> Operator, { @@ -407,8 +350,8 @@ impl ElementWiseBuilder { return false; } - let input = self.builder.input(&desc.input, elem_input); - let out = self.builder.output(&desc.out, elem_out); + let input = self.builder.input(&desc.input); + let out = self.builder.output(&desc.out); self.builder.register_operation(func(input, out)); @@ -418,7 +361,6 @@ impl ElementWiseBuilder { fn register_scalar_ops( &mut self, desc: &ScalarOperationDescription, - (elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem), func: Func, ) -> bool where @@ -428,9 +370,9 @@ impl ElementWiseBuilder { return false; } - let lhs = self.builder.input(&desc.lhs, elem_lhs); - let rhs = self.builder.scalar(&desc.rhs, elem_rhs); - let out = self.builder.output(&desc.out, elem_out); + let lhs = self.builder.input(&desc.lhs); + let rhs = self.builder.scalar(&desc.rhs, desc.lhs.dtype.into()); + let out = self.builder.output(&desc.out); self.builder.register_operation(func(lhs, rhs, out)); @@ -439,7 +381,7 @@ impl ElementWiseBuilder { fn output_is_compatible(&mut self, out: &TensorDescription) -> bool { if self.current_output_shape.is_empty() { - self.current_output_shape = out.shape.clone(); + self.current_output_shape.clone_from(&out.shape); } else if self.current_output_shape != out.shape { return false; } diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index ac6010bbdb..afa5198230 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ 
b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -7,8 +7,8 @@ use super::{ use crate::{ codegen::dialect::gpu::WorkgroupSize, compute::JitAutotuneKey, - fusion::{kernel::FusionKernel, tracing::Trace}, - FloatElement, IntElement, JitBackend, Runtime, + fusion::{kernel::FusionKernel, tracing::Trace, JitFusionHandle}, + Runtime, }; use burn_common::id::IdGenerator; use burn_compute::client::ComputeClient; @@ -66,10 +66,7 @@ impl ElementWise { } impl ElementWise> { - pub(crate) fn execute( - &mut self, - context: &mut Context<'_, JitBackend>, - ) { + pub(crate) fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { let client = R::client(&self.device); let key = JitAutotuneKey::FusionElemWise(FusionElemWiseAutotuneKey::new( @@ -84,9 +81,9 @@ impl ElementWise> { } } - fn run_kernel( + fn run_kernel( &mut self, - context: &mut Context<'_, JitBackend>, + context: &mut Context<'_, JitFusionHandle>, client: ComputeClient, fastest_set_index: usize, ) { @@ -109,9 +106,9 @@ impl ElementWise> { kernel.execute(); } - fn run_autotune( + fn run_autotune( &mut self, - context: &mut Context<'_, JitBackend>, + context: &mut Context<'_, JitFusionHandle>, client: ComputeClient, key: JitAutotuneKey, ) { @@ -155,9 +152,9 @@ impl ElementWise> { } /// The first output is chosen when possible, otherwise the first input is chosen. - pub(crate) fn autotune_shape<'a, F: FloatElement, I: IntElement>( + pub(crate) fn autotune_shape<'a>( &self, - context: &mut Context<'a, JitBackend>, + context: &mut Context<'a, JitFusionHandle>, ) -> &'a [usize] { let info = self.trace.running(); diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs index c15a417201..334054adf2 100644 --- a/crates/burn-jit/src/fusion/kernel.rs +++ b/crates/burn-jit/src/fusion/kernel.rs @@ -10,9 +10,6 @@ use crate::fusion::strides_dyn_rank; use crate::fusion::JitFusionHandle; use crate::gpu::ComputeShader; use crate::kernel::GpuComputeShaderPhase; -use crate::FloatElement; -use crate::IntElement; -use crate::JitBackend; use crate::Runtime; use burn_compute::client::ComputeClient; use burn_compute::server::Binding; @@ -20,7 +17,6 @@ use burn_compute::tune::AutotuneOperation; use burn_fusion::stream::Context; use burn_tensor::repr::TensorDescription; use burn_tensor::repr::TensorStatus; -use burn_tensor::Device; use std::marker::PhantomData; use std::sync::Arc; @@ -109,18 +105,16 @@ impl From> for AutotunableKernel { } impl FusionKernel { - pub fn create( + pub fn create( factory: &K, running_info: &ExecutionInfo<'_>, - context: &mut Context<'_, JitBackend>, - device: Device>, + context: &mut Context<'_, JitFusionHandle>, + device: R::Device, client: ComputeClient, stateful: bool, ) -> ExecutableKernel where K: FusionKernelFactory, - F: FloatElement, - I: IntElement, { let (handles_input, inputs_description_updated, outputs_description_updated) = process_inputs_outputs( @@ -287,10 +281,10 @@ fn register_info_tensor( } } -fn process_inputs_outputs<'a, R, F, I>( +fn process_inputs_outputs<'a, R>( inputs: &[&TensorDescription], outputs: &[&TensorDescription], - context: &'a mut Context<'_, JitBackend>, + context: &'a mut Context<'_, JitFusionHandle>, stateful: bool, ) -> ( Vec>, @@ -299,8 +293,6 @@ fn process_inputs_outputs<'a, R, F, I>( ) where R: Runtime, - F: FloatElement, - I: IntElement, { let mut inputs_description_updated = Vec::with_capacity(inputs.len()); let mut outputs_description_updated = Vec::with_capacity(outputs.len()); diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs 
diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs
index 1cec52ddac..3a7f969fd7 100644
--- a/crates/burn-jit/src/fusion/tracing/builder.rs
+++ b/crates/burn-jit/src/fusion/tracing/builder.rs
@@ -37,8 +37,9 @@ impl TraceBuilder {
     }

     /// Create a variable from an input [tensor description](TensorDescription).
-    pub fn input(&mut self, tensor: &TensorDescription, elem: gpu::Elem) -> gpu::Variable {
+    pub fn input(&mut self, tensor: &TensorDescription) -> gpu::Variable {
         let already_exists = self.tensors.contains_key(&tensor.id);
+        let elem = tensor.dtype.into();

         let variable = match already_exists {
             false => {
@@ -72,7 +73,8 @@ impl TraceBuilder {
     }

     /// Create a variable from an output [tensor description](TensorDescription).
-    pub fn output(&mut self, tensor: &TensorDescription, elem: gpu::Elem) -> gpu::Variable {
+    pub fn output(&mut self, tensor: &TensorDescription) -> gpu::Variable {
+        let elem = tensor.dtype.into();
         // Update the tensor description to the new version.
         self.tensors.insert(tensor.id, (tensor.clone(), elem));
diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs
index 60bb659daa..ae6cfb55aa 100644
--- a/crates/burn-jit/src/kernel/matmul/base.rs
+++ b/crates/burn-jit/src/kernel/matmul/base.rs
@@ -85,7 +85,6 @@ impl Default for Tiling2dConfig {
 }

 /// The strategy to be used when launching a matmul kernel.
-#[derive(Default)]
 pub enum MatmulStrategy {
     /// A simple kernel will be used with memory coalescing optimization.
     Simple {
@@ -100,11 +99,17 @@ pub enum MatmulStrategy {
     Tiling2dPadded(Tiling2dConfig),
     #[cfg(feature = "autotune")]
     /// Using autotune to chose the best kernel based on runtime information.
-    #[default]
     Autotune,
 }

+#[allow(clippy::derivable_impls)] // Necessary, otherwise the feature flags don't work.
+#[cfg(feature = "autotune")]
+impl Default for MatmulStrategy {
+    fn default() -> Self {
+        MatmulStrategy::Autotune
+    }
+}
+
 #[cfg(not(feature = "autotune"))]
 impl Default for MatmulStrategy {
     fn default() -> Self {
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
index 3250fb54c4..2ef3b17d0a 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
@@ -20,11 +20,6 @@ use super::{
     Tiling2dConfig,
 };

-#[derive(new, Debug)]
-struct MatmulTiling2d<E: JitElement> {
-    _elem: PhantomData<E>,
-}
-
 #[derive(new, Debug)]
 struct MatmulTiling2dEagerKernel<R: Runtime> {
     config: Tiling2dConfig,
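Note: `#[default]` on a `#[cfg]`-gated variant breaks the `Default` derive when the feature is disabled, so `Default` is now implemented by hand under each feature combination. The pattern in isolation (illustrative enum, not from this PR):

    pub enum Strategy {
        Simple,
        #[cfg(feature = "autotune")]
        Autotune,
    }

    // Pick a different default variant depending on a compile-time feature.
    #[cfg(feature = "autotune")]
    impl Default for Strategy {
        fn default() -> Self {
            Strategy::Autotune
        }
    }

    #[cfg(not(feature = "autotune"))]
    impl Default for Strategy {
        fn default() -> Self {
            Strategy::Simple
        }
    }
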
diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs
index 89cf6b8fc2..1fff72f176 100644
--- a/crates/burn-tensor/src/repr/handle.rs
+++ b/crates/burn-tensor/src/repr/handle.rs
@@ -1,5 +1,4 @@
 use crate::{
-    backend::Backend,
     repr::{
         backend::ReprBackend,
         tensor::{TensorDescription, TensorId, TensorStatus},
@@ -11,36 +10,33 @@ use std::{collections::HashMap, sync::Arc};
 /// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources
 /// are used optimally.
 #[derive(Default)]
-pub struct HandleContainer<B: ReprBackend> {
-    handles: HashMap<TensorId, Handle<B>>,
+pub struct HandleContainer<H> {
+    handles: HashMap<TensorId, Handle<H>>,
     counter: u64,
     /// Handle candidates to be freed.
     pub handles_orphan: Vec<TensorId>,
-    /// The device on which all tensors are held.
-    pub device: B::Device,
 }

 /// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state
-pub enum Handle<B: ReprBackend> {
+pub enum Handle<H> {
     /// No [tensor handle](ReprBackend::Handle) has been created yet
     NotInit,
     /// A [tensor handle](ReprBackend::Handle) has been created
-    Existing(B::Handle),
+    Existing(H),
 }

-impl<B: ReprBackend> HandleContainer<B> {
+impl<H: Clone> HandleContainer<H> {
     /// Create a new HandleContainer
-    pub fn new(device_handle: B::Device) -> Self {
+    pub fn new() -> Self {
         Self {
             handles: HashMap::new(),
             handles_orphan: Vec::new(),
             counter: 0,
-            device: device_handle.clone(),
         }
     }

     /// Register a handle for the given [tensor id](TensorId).
-    pub fn register_handle(&mut self, id: TensorId, handle: B::Handle) {
+    pub fn register_handle(&mut self, id: TensorId, handle: H) {
         self.handles.insert(id, Handle::Existing(handle));
     }

@@ -51,7 +47,7 @@ impl<H: Clone> HandleContainer<H> {
     ///
     /// Make sure the status corresponds to the operation you want to execute the handle on,
     /// otherwise you might remove a tensor handle that will be required in the future.
-    pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> B::Handle {
+    pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
         let (id, handle) = self
             .handles
             .remove_entry(id)
@@ -70,68 +66,83 @@ impl<H: Clone> HandleContainer<H> {
         }
     }

-    /// Get the [float tensor](Backend::FloatTensorPrimitive) corresponding to the
+    /// Get the [float tensor](crate::backend::Backend::FloatTensorPrimitive) corresponding to the
     /// given [tensor description](TensorDescription).
-    pub fn get_float_tensor<const D: usize>(
+    pub fn get_float_tensor<B, const D: usize>(
         &mut self,
         tensor: &TensorDescription,
-    ) -> B::FloatTensorPrimitive<D> {
+    ) -> B::FloatTensorPrimitive<D>
+    where
+        B: ReprBackend<Handle = H>,
+    {
         B::float_tensor::<D>(
             self.get_handle(&tensor.id, &tensor.status),
             Shape::from(&tensor.shape),
         )
     }

-    /// Get the [int tensor](Backend::IntTensorPrimitive) corresponding to the
+    /// Get the [int tensor](crate::backend::Backend::IntTensorPrimitive) corresponding to the
     /// given [tensor description](TensorDescription).
-    pub fn get_int_tensor<const D: usize>(
+    pub fn get_int_tensor<B, const D: usize>(
         &mut self,
         tensor: &TensorDescription,
-    ) -> B::IntTensorPrimitive<D> {
+    ) -> B::IntTensorPrimitive<D>
+    where
+        B: ReprBackend<Handle = H>,
+    {
         B::int_tensor::<D>(
             self.get_handle(&tensor.id, &tensor.status),
             Shape::from(&tensor.shape),
         )
     }

-    /// Get the [bool tensor](Backend::BoolTensorPrimitive) corresponding to the
+    /// Get the [bool tensor](crate::backend::Backend::BoolTensorPrimitive) corresponding to the
     /// given [tensor description](TensorDescription).
-    pub fn get_bool_tensor<const D: usize>(
+    pub fn get_bool_tensor<B, const D: usize>(
         &mut self,
         tensor: &TensorDescription,
-    ) -> B::BoolTensorPrimitive<D> {
+    ) -> B::BoolTensorPrimitive<D>
+    where
+        B: ReprBackend<Handle = H>,
+    {
         B::bool_tensor::<D>(
             self.get_handle(&tensor.id, &tensor.status),
             Shape::from(&tensor.shape),
         )
     }

-    /// Register a new [float tensor](Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
-    pub fn register_float_tensor<const D: usize>(
+    /// Register a new [float tensor](crate::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
+    pub fn register_float_tensor<B, const D: usize>(
         &mut self,
         id: &TensorId,
         tensor: B::FloatTensorPrimitive<D>,
-    ) {
+    ) where
+        B: ReprBackend<Handle = H>,
+    {
         let handle = B::float_tensor_handle::<D>(tensor);
         self.handles.insert(*id, Handle::Existing(handle));
     }

-    /// Register a new [int tensor](Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
-    pub fn register_int_tensor<const D: usize>(
+    /// Register a new [int tensor](crate::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
+    pub fn register_int_tensor<B, const D: usize>(
         &mut self,
         id: &TensorId,
         tensor: B::IntTensorPrimitive<D>,
-    ) {
+    ) where
+        B: ReprBackend<Handle = H>,
+    {
         let handle = B::int_tensor_handle::<D>(tensor);
         self.handles.insert(*id, Handle::Existing(handle));
     }

-    /// Register a new [bool tensor](Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
-    pub fn register_bool_tensor<const D: usize>(
+    /// Register a new [bool tensor](crate::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
+    pub fn register_bool_tensor<B, const D: usize>(
         &mut self,
         id: &TensorId,
         tensor: B::BoolTensorPrimitive<D>,
-    ) {
+    ) where
+        B: ReprBackend<Handle = H>,
+    {
         let handle = B::bool_tensor_handle::<D>(tensor);
         self.handles.insert(*id, Handle::Existing(handle));
     }
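Note: `HandleContainer` is now generic over the raw handle type `H` instead of a whole backend, and it no longer owns a device; each typed accessor picks a backend at the call site through `B: ReprBackend<Handle = H>`. A hedged usage sketch under those assumed signatures (the function and its arguments are illustrative):

    // Illustrative only: the backend (and rank) are chosen per call rather
    // than per container, and no device argument is needed anymore.
    fn example<B: ReprBackend>(id: TensorId, handle: B::Handle, desc: &TensorDescription) {
        let mut container: HandleContainer<B::Handle> = HandleContainer::new();
        container.register_handle(id, handle);
        let _float = container.get_float_tensor::<B, 2>(desc);
    }
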
diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs
index 220d8a90cf..22811d667d 100644
--- a/crates/burn-tensor/src/repr/operation.rs
+++ b/crates/burn-tensor/src/repr/operation.rs
@@ -195,6 +195,8 @@ pub enum BaseOperationDescription {
     /// Int => [cat](crate::ops::IntTensorOps::int_cat).
     /// Bool => [cat](crate::ops::BoolTensorOps::bool_cat).
     Cat(CatOperationDescription),
+    /// Cast operation; it has no direct tensor-ops counterpart and should be supported by the fusion backend.
+    Cast(UnaryOperationDescription),
 }

 /// Numeric operations on int and float tensors.
@@ -1102,6 +1104,7 @@ impl BaseOperationDescription {
                 vec![&desc.tensor, &desc.out]
             }
             BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),
+            BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out],
         }
     }
 }
diff --git a/crates/burn-tensor/src/repr/tensor.rs b/crates/burn-tensor/src/repr/tensor.rs
index 525ad9c50b..a68d6b9c2f 100644
--- a/crates/burn-tensor/src/repr/tensor.rs
+++ b/crates/burn-tensor/src/repr/tensor.rs
@@ -1,5 +1,7 @@
 use serde::{Deserialize, Serialize};

+use crate::DType;
+
 /// The tensor unique identifier.
 #[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
 pub struct TensorId {
@@ -35,6 +37,8 @@ pub struct TensorDescription {
     pub shape: Vec<usize>,
     /// The [status](TensorStatus) of the tensor when it was used.
     pub status: TensorStatus,
+    /// The [type](DType) of the tensor.
+    pub dtype: DType,
 }

 impl TensorId {
diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs
index 35c0d49eeb..f1d38d049d 100644
--- a/crates/burn-tensor/src/tensor/data.rs
+++ b/crates/burn-tensor/src/tensor/data.rs
@@ -112,9 +112,10 @@ impl Distribution {
     /// # Returns
     ///
     /// The distribution sampler.
-    pub fn sampler<R: RngCore, E: Element>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
+    pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
     where
-        E: rand::distributions::uniform::SampleUniform,
+        R: RngCore,
+        E: Element + rand::distributions::uniform::SampleUniform,
         Standard: rand::distributions::Distribution<E>,
     {
         let kind = match self {
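Note: `TensorDescription` now records a `DType` alongside shape and status, which is exactly what the fusion tracing above consumes, and `Cast` becomes a first-class base operation. An illustrative construction under the new definition (the shape and dtype values are made up; real descriptions are produced by the fusion stream):

    use burn_tensor::repr::{TensorDescription, TensorId, TensorStatus};
    use burn_tensor::DType;

    // Build a description for a read-only f16 tensor; `id` comes from elsewhere.
    fn describe(id: TensorId) -> TensorDescription {
        TensorDescription {
            id,
            shape: vec![32, 128],
            status: TensorStatus::ReadOnly,
            dtype: DType::F16, // the new field that drives runtime codegen
        }
    }
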
diff --git a/crates/burn-tensor/src/tensor/element.rs b/crates/burn-tensor/src/tensor/element.rs
index d8eaac94b6..7a2c99a9bc 100644
--- a/crates/burn-tensor/src/tensor/element.rs
+++ b/crates/burn-tensor/src/tensor/element.rs
@@ -4,6 +4,7 @@ use crate::Distribution;
 use half::{bf16, f16};
 use num_traits::{identities::Zero, One, ToPrimitive};
 use rand::RngCore;
+use serde::{Deserialize, Serialize};

 /// Element trait for tensor.
 pub trait Element:
@@ -22,6 +23,8 @@ pub trait Element:
     + Copy
     + 'static
 {
+    /// The dtype of the element.
+    fn dtype() -> DType;
 }

 /// Element conversion trait for tensor.
@@ -53,9 +56,7 @@ pub trait ElementRandom {
     /// # Returns
     ///
     /// The random value.
-    fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self
-    where
-        Self: Sized;
+    fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
 }

 /// Element ordering trait.
@@ -93,10 +94,15 @@ macro_rules! make_element {
         ty $type:ident $precision:expr,
         convert $convert:expr,
         random $random:expr,
-        cmp $cmp:expr
+        cmp $cmp:expr,
+        dtype $dtype:expr
     ) => {
-        impl Element for $type {}
+        impl Element for $type {
+            fn dtype() -> $crate::DType {
+                $dtype
+            }
+        }

         impl ElementConversion for $type {
             fn from_elem<E: ToPrimitive>(elem: E) -> Self {
@@ -136,56 +142,64 @@ make_element!(
     ty f64 Precision::Double,
     convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
-    cmp |a: &f64, b: &f64| a.total_cmp(b)
+    cmp |a: &f64, b: &f64| a.total_cmp(b),
+    dtype DType::F64
 );
 make_element!(
     ty f32 Precision::Full,
     convert |elem: &dyn ToPrimitive| elem.to_f32().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
-    cmp |a: &f32, b: &f32| a.total_cmp(b)
+    cmp |a: &f32, b: &f32| a.total_cmp(b),
+    dtype DType::F32
 );
 make_element!(
     ty i64 Precision::Double,
     convert |elem: &dyn ToPrimitive| elem.to_i64().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
-    cmp |a: &i64, b: &i64| Ord::cmp(a, b)
+    cmp |a: &i64, b: &i64| Ord::cmp(a, b),
+    dtype DType::I64
 );
 make_element!(
     ty i32 Precision::Full,
     convert |elem: &dyn ToPrimitive| elem.to_i32().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
-    cmp |a: &i32, b: &i32| Ord::cmp(a, b)
+    cmp |a: &i32, b: &i32| Ord::cmp(a, b),
+    dtype DType::I32
 );
 make_element!(
     ty u32 Precision::Full,
     convert |elem: &dyn ToPrimitive| elem.to_u32().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
-    cmp |a: &u32, b: &u32| Ord::cmp(a, b)
+    cmp |a: &u32, b: &u32| Ord::cmp(a, b),
+    dtype DType::U32
 );
 make_element!(
     ty i16 Precision::Half,
     convert |elem: &dyn ToPrimitive| elem.to_i16().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
-    cmp |a: &i16, b: &i16| Ord::cmp(a, b)
+    cmp |a: &i16, b: &i16| Ord::cmp(a, b),
+    dtype DType::I16
 );
 make_element!(
     ty i8 Precision::Other,
     convert |elem: &dyn ToPrimitive| elem.to_i8().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
-    cmp |a: &i8, b: &i8| Ord::cmp(a, b)
+    cmp |a: &i8, b: &i8| Ord::cmp(a, b),
+    dtype DType::I8
 );
 make_element!(
     ty u8 Precision::Other,
     convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(),
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
-    cmp |a: &u8, b: &u8| Ord::cmp(a, b)
+    cmp |a: &u8, b: &u8| Ord::cmp(a, b),
+    dtype DType::U8
 );
 make_element!(
     ty f16 Precision::Half,
@@ -195,7 +209,8 @@ make_element!(
         let sample: f32 = distribution.sampler(rng).sample();
         f16::from_elem(sample)
     },
-    cmp |a: &f16, b: &f16| a.total_cmp(b)
+    cmp |a: &f16, b: &f16| a.total_cmp(b),
+    dtype DType::F16
 );
 make_element!(
     ty bf16 Precision::Half,
@@ -204,5 +219,23 @@ make_element!(
         let sample: f32 = distribution.sampler(rng).sample();
         bf16::from_elem(sample)
     },
-    cmp |a: &bf16, b: &bf16| a.total_cmp(b)
+    cmp |a: &bf16, b: &bf16| a.total_cmp(b),
+    dtype DType::BF16
 );
+
+#[allow(missing_docs)]
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
+pub enum DType {
+    F64,
+    F32,
+    F16,
+    BF16,
+    I64,
+    I32,
+    I16,
+    I8,
+    U64,
+    U32,
+    U8,
+    Bool,
+}
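Note: the `Element` trait now exposes its runtime tag directly, so generic code can map a compile-time element type to a `DType` without extra plumbing. A minimal sketch, assuming `DType` and `Element` are re-exported at the crate root as the `use crate::DType;` line above suggests:

    use burn_tensor::{DType, Element};

    // Recover the runtime dtype tag from a compile-time element type.
    fn dtype_of<E: Element>() -> DType {
        E::dtype()
    }

    #[test]
    fn dtype_tags_match() {
        assert_eq!(dtype_of::<f32>(), DType::F32);
        assert_eq!(dtype_of::<half::bf16>(), DType::BF16);
    }
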
diff --git a/examples/custom-training-loop/src/lib.rs b/examples/custom-training-loop/src/lib.rs
index 89b9bd80f9..5d01dfadd1 100644
--- a/examples/custom-training-loop/src/lib.rs
+++ b/examples/custom-training-loop/src/lib.rs
@@ -159,7 +159,7 @@ where
 #[allow(dead_code)]
 impl Learner2 {
-    pub fn step3(&mut self, _batch: MnistBatch<B>)
+    pub fn step3(&mut self, _batch: MnistBatch<B>)
     where
         B: AutodiffBackend,
         M: AutodiffModule<B>,
diff --git a/examples/text-classification/src/data/tokenizer.rs b/examples/text-classification/src/data/tokenizer.rs
index 3d1044b365..4b1c8adee3 100644
--- a/examples/text-classification/src/data/tokenizer.rs
+++ b/examples/text-classification/src/data/tokenizer.rs
@@ -5,6 +5,7 @@
 // This trait represents the common interface for all tokenizer types.
 // The `Send + Sync` bounds are necessary for allowing these operations
 // to work across thread boundaries.
+#[allow(dead_code)]
 pub trait Tokenizer: Send + Sync {
     /// Converts a text string into a sequence of tokens.
     fn encode(&self, value: &str) -> Vec<usize>;
diff --git a/examples/text-generation/src/data/tokenizer.rs b/examples/text-generation/src/data/tokenizer.rs
index cf6fc81bae..53b294bc3f 100644
--- a/examples/text-generation/src/data/tokenizer.rs
+++ b/examples/text-generation/src/data/tokenizer.rs
@@ -1,3 +1,4 @@
+#[allow(dead_code)]
 pub trait Tokenizer: Send + Sync {
     fn encode(&self, value: &str, special_tokens: bool) -> Vec<usize>;
     fn decode(&self, tokens: &[usize]) -> String;