Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fusion] Support multi-precision fusion #1718

Merged
merged 27 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions crates/burn-common/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ impl<T> Reader<T> {
}

/// Map the current reader to another type.
pub fn map<O, F: FnOnce(T) -> O>(self, mapper: F) -> Reader<O>
pub fn map<O, F>(self, mapper: F) -> Reader<O>
where
T: 'static + Send,
O: 'static + Send,
F: 'static + Send,
F: FnOnce(T) -> O + 'static + Send,
{
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
return Reader::Async(Box::new(MappedReader::new(self, mapper)));
Expand Down
11 changes: 11 additions & 0 deletions crates/burn-compute/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
}

/// `Default` for [`ComputeRuntime`]: equivalent to calling `ComputeRuntime::new()`.
///
/// The `Device` bound is the same one used by the inherent `impl` of `ComputeRuntime`.
impl<Device, Server, Channel> Default for ComputeRuntime<Device, Server, Channel>
where
Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
Server: ComputeServer,
Channel: ComputeChannel<Server>,
{
/// Creates the runtime by delegating to `Self::new()`.
fn default() -> Self {
Self::new()
}
}

impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
where
Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
Expand Down
12 changes: 6 additions & 6 deletions crates/burn-core/src/record/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,13 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
fn clone_unsafely<T>(thing: &T) -> T {
unsafe {
// Allocate memory for the clone.
let clone = ptr::null_mut();
// Correcting pointer usage based on feedback
let clone = ptr::addr_of_mut!(*clone);
let mut clone = std::mem::MaybeUninit::<T>::uninit();
// Get a mutable pointer to the allocated memory.
let clone_ptr = clone.as_mut_ptr();
// Copy the memory
ptr::copy_nonoverlapping(thing as *const T, clone, 1);
// Transmute the cloned data pointer into an owned instance of T.
ptr::read(clone)
ptr::copy_nonoverlapping(thing as *const T, clone_ptr, 1);
// Assume the cloned data is initialized and convert it to an owned instance of T.
clone.assume_init()
}
}

Expand Down
24 changes: 10 additions & 14 deletions crates/burn-core/src/record/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ impl SerializerTrait for Serializer {
Ok(self)
}

fn serialize_newtype_struct<T: ?Sized>(
fn serialize_newtype_struct<T>(
self,
_name: &'static str,
value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
value.serialize(self)
}
Expand Down Expand Up @@ -128,9 +128,9 @@ impl SerializerTrait for Serializer {
unimplemented!()
}

fn serialize_some<T: ?Sized>(self, value: &T) -> Result<Self::Ok, Self::Error>
fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
value.serialize(self)
}
Expand All @@ -152,15 +152,15 @@ impl SerializerTrait for Serializer {
unimplemented!()
}

fn serialize_newtype_variant<T: ?Sized>(
fn serialize_newtype_variant<T>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
unimplemented!()
}
Expand Down Expand Up @@ -207,13 +207,9 @@ impl SerializeStruct for Serializer {
type Ok = NestedValue;
type Error = Error;

fn serialize_field<T: ?Sized>(
&mut self,
key: &'static str,
value: &T,
) -> Result<(), Self::Error>
fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
let serialized_value = value.serialize(Serializer::new())?;

Expand Down Expand Up @@ -248,9 +244,9 @@ impl SerializeSeq for Serializer {
type Ok = NestedValue;
type Error = Error;

fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
let serialized_value = value.serialize(Serializer::new())?;

Expand Down
74 changes: 54 additions & 20 deletions crates/burn-fusion/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use crate::{
client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
};
use burn_tensor::{
backend::Backend,
backend::{Backend, DeviceOps},
ops::FloatTensor,
repr::{OperationDescription, ReprBackend},
Device,
};
Expand All @@ -11,30 +12,30 @@ use std::marker::PhantomData;

pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();

pub(crate) fn get_client<B: FusionBackend>(device: &B::Device) -> B::FusionClient {
CLIENTS.client(device)
pub(crate) fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
CLIENTS.client::<B::FusionRuntime>(device)
}

/// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).
#[derive(Clone, Debug, Default)]
pub struct Fusion<B> {
pub struct Fusion<B: FusionBackend> {
_backend: PhantomData<B>,
}

impl<B: FusionBackend> Backend for Fusion<B> {
type Device = B::Device;

type FullPrecisionBridge = PrecisionBridge;
type FullPrecisionBridge = PrecisionBridge<B::FullPrecisionBackend>;

type FloatTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
type FloatTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;

type FloatElem = B::FloatElem;

type IntTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
type IntTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;

type IntElem = B::IntElem;

type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;

fn name() -> String {
format!("fusion<{}>", B::name())
Expand All @@ -45,10 +46,14 @@ impl<B: FusionBackend> Backend for Fusion<B> {
}

fn sync(device: &Self::Device) {
let client = CLIENTS.client::<B::FusionClient>(&device.clone());
let client = CLIENTS.client::<B::FusionRuntime>(&device.clone());
client.drain();
B::sync(device)
}

fn ad_enabled() -> bool {
false
}
}

/// The status of a [builder](OptimizationBuilder).
Expand Down Expand Up @@ -101,32 +106,61 @@ pub trait OptimizationBuilder<O>: Send {
}

/// The operation created from the [builder](OptimizationBuilder).
pub trait Optimization<B: FusionBackend>: Send {
pub trait Optimization<R: FusionRuntime>: Send {
/// Execute the operation.
fn execute(&mut self, context: &mut Context<'_, B>);
fn execute(&mut self, context: &mut Context<'_, R::FusionHandle>);
/// The number of registered operations in this optimization.
fn len(&self) -> usize;
/// Returns `true` if the current optimization is empty.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the state that can be serialized.
fn to_state(&self) -> B::OptimizationState;
fn to_state(&self) -> R::OptimizationState;
/// Create the optimization from the state.
fn from_state(device: &B::Device, state: B::OptimizationState) -> Self;
fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
}

/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation builder](crate::OptimizationBuilder).
pub trait FusionBackend: Backend + ReprBackend {
/// The device type of the fusion runtime `R`.
///
/// Type alias for `<R as FusionRuntime>::FusionDevice`.
pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
/// The tensor handle type of the fusion runtime `R`.
///
/// Type alias for `<R as FusionRuntime>::FusionHandle`.
pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
/// The client type of the fusion runtime `R`.
///
/// Type alias for `<R as FusionRuntime>::FusionClient`.
pub type Client<R> = <R as FusionRuntime>::FusionClient;

/// Trait that defines a runtime that benefits from fused operations.
pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug {
/// The state that can be serialized for an optimization.
type OptimizationState: Serialize + DeserializeOwned;
/// Optimization type for the backend.
type Optimization: Optimization<Self>;
/// What kind of client should be used.
type FusionClient: FusionClient<FusionBackend = Self>;
/// Handle used to store tensor dynamically.
type FusionHandle: Clone + Send;
/// Device used by the runtime.
type FusionDevice: DeviceOps;
/// The client to interact with the runtime.
type FusionClient: FusionClient<Self>;

/// The list of optimizations that will be used to optimize the computational graph.
fn optimizations(device: Device<Self>)
-> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
fn optimizations(
device: Self::FusionDevice,
) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
}

/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation builder](crate::OptimizationBuilder).
///
/// The supertrait bound ties the backend to its runtime: the backend's `Handle` must be the
/// runtime's `FusionHandle` and its `Device` must be the runtime's `FusionDevice`.
pub trait FusionBackend:
ReprBackend<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
{
/// The runtime used for this backend.
type FusionRuntime: FusionRuntime;

/// Casts a float tensor to the given `dtype` and returns the resulting handle.
fn cast_float<const D: usize>(
tensor: FloatTensor<Self, D>,
dtype: burn_tensor::DType,
) -> Self::Handle;

/// The full-precision counterpart of this backend.
///
/// This is an associated type (not a pointer); it must itself be a [`FusionBackend`] that
/// uses the same `FusionRuntime` as this backend.
type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
}
95 changes: 82 additions & 13 deletions crates/burn-fusion/src/bridge.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,94 @@
use burn_tensor::backend::BackendBridge;

use crate::{Fusion, FusionBackend};
use crate::{
client::FusionClient, stream::execution::Operation, Fusion, FusionBackend, FusionRuntime,
};
use burn_tensor::{
backend::BackendBridge,
ops::FloatTensor,
repr::{
BaseOperationDescription, HandleContainer, OperationDescription, UnaryOperationDescription,
},
Element,
};
use std::marker::PhantomData;

#[derive(Debug)]
/// Fusion bridge.
pub struct PrecisionBridge;
pub struct PrecisionBridge<B: FusionBackend> {
_backend: PhantomData<B>,
}

impl<B: FusionBackend> BackendBridge<Fusion<B>> for PrecisionBridge {
type Target = Fusion<B>;
impl<R, BInput, BTarget> BackendBridge<Fusion<BInput>> for PrecisionBridge<BTarget>
where
BInput: FusionBackend<FusionRuntime = R>,
BTarget: FusionBackend<FusionRuntime = R>,
R: FusionRuntime + 'static,
{
type Target = Fusion<BTarget>;

fn into_target<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Fusion<B>, D>,
tensor: FloatTensor<Fusion<BInput>, D>,
_device: Option<burn_tensor::Device<Self::Target>>,
) -> burn_tensor::ops::FloatTensor<Self::Target, D> {
tensor
) -> FloatTensor<Self::Target, D> {
cast::<R, BInput, BTarget, D>(tensor)
}

fn from_target<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self::Target, D>,
_device: Option<burn_tensor::Device<Fusion<B>>>,
) -> burn_tensor::ops::FloatTensor<Fusion<B>, D> {
tensor
tensor: FloatTensor<Self::Target, D>,
_device: Option<burn_tensor::Device<Fusion<BInput>>>,
) -> FloatTensor<Fusion<BInput>, D> {
cast::<R, BTarget, BInput, D>(tensor)
}
}

/// Registers a float cast from `BInput`'s float element type to `BTarget`'s on the fusion
/// client and returns the output tensor.
///
/// Both backends must share the same runtime `R`. The actual `cast_float` call lives in
/// `Cast::execute`; this function only builds the operation description and hands it, together
/// with the operation, to `client.register`. When the operation runs is up to the client and is
/// not visible from this function.
fn cast<R, BInput, BTarget, const D: usize>(
input: FloatTensor<Fusion<BInput>, D>,
) -> FloatTensor<Fusion<BTarget>, D>
where
BInput: FusionBackend<FusionRuntime = R>,
BTarget: FusionBackend<FusionRuntime = R>,
R: FusionRuntime + 'static,
{
// Operation payload: the unary description plus markers for the generic parameters.
// NOTE(review): `new` comes from `#[derive(new)]` (presumably the `derive-new` crate); the call
// site below passes only `desc`, so the `PhantomData` fields are expected to be filled in by
// the derive — confirm.
#[derive(new)]
struct Cast<R: FusionRuntime, BInput: FusionBackend, BTarget: FusionBackend, const D: usize> {
desc: UnaryOperationDescription,
_bi: PhantomData<BInput>,
_bt: PhantomData<BTarget>,
_runtime: PhantomData<R>,
}

impl<const D: usize, R, BInput, BTarget> Operation<BTarget::FusionRuntime>
for Cast<R, BInput, BTarget, D>
where
BInput: FusionBackend<FusionRuntime = R>,
BTarget: FusionBackend<FusionRuntime = R>,
R: FusionRuntime,
{
// Fetches the input as a `BInput` float tensor, casts it to `BTarget`'s float dtype, and
// registers the resulting handle under the output tensor's id.
fn execute(
self: Box<Self>,
handles: &mut HandleContainer<<BTarget::FusionRuntime as FusionRuntime>::FusionHandle>,
) {
let input = handles.get_float_tensor::<BInput, D>(&self.desc.input);
let output = BInput::cast_float(input, BTarget::FloatElem::dtype());

handles.register_handle(self.desc.out.id, output);
}
}

// Read the stream first: `input` is not used again after `into_description()` below
// (presumably that call consumes it — confirm).
let stream = input.stream;
// Ask the input's client for an output tensor with the same shape and `BTarget`'s float dtype.
let out = input
.client
.tensor_uninitialized(input.shape.clone(), BTarget::FloatElem::dtype());

let desc = UnaryOperationDescription {
input: input.into_description(),
out: out.to_description_out(),
};

// Register the cast on the output's client, on the input's stream; the description is cloned
// because both the `OperationDescription` and the `Cast` operation need their own copy.
out.client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())),
Cast::<R, BInput, BTarget, D>::new(desc),
);

out
}
Loading
Loading