Skip to content

Commit

Permalink
Refactor/burn ir (#2796)
Browse files Browse the repository at this point in the history
* Move IR to its own crate

* Rename IR types to Repr

* Rename description to (intermediate) repr

* Fix deps

* Add publish workflow

* Fix README

* Change to Ir suffix

* Cargo fmt

* Fix gather op
  • Loading branch information
laggui authored Feb 11, 2025
1 parent ffe8cd7 commit 5b68f0a
Show file tree
Hide file tree
Showing 76 changed files with 5,122 additions and 5,760 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
with:
crate: burn-router
needs:
- publish-burn-ir
- publish-burn-common
- publish-burn-tensor
# dev dependencies
Expand All @@ -44,6 +45,7 @@ jobs:
with:
crate: burn-remote
needs:
- publish-burn-ir
- publish-burn-common
- publish-burn-tensor
- publish-burn-router
Expand Down Expand Up @@ -90,9 +92,19 @@ jobs:
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

publish-burn-ir:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-tensor
with:
crate: burn-ir
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

publish-burn-fusion:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-ir
- publish-burn-tensor
- publish-burn-common
with:
Expand All @@ -103,6 +115,7 @@ jobs:
publish-burn-jit:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-ir
- publish-burn-common
- publish-burn-fusion
- publish-burn-tensor
Expand Down Expand Up @@ -136,6 +149,7 @@ jobs:
publish-burn-ndarray:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-ir
- publish-burn-tensor
- publish-burn-autodiff
- publish-burn-common
Expand Down
16 changes: 15 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions contributor-book/src/guides/adding-a-new-operation-to-burn.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@ Here's how powf was added to `burn-fusion`:

1. Added powf to the float ops under
[`crates/burn-fusion/src/ops/float.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-fusion/src/ops/float.rs#L1838)
2. Added powf to the `NumericOperationDescription` enum under
2. Added powf to the `NumericOperationIr` enum under
[crates/burn-fusion/src/stream/operation.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-fusion/src/stream/operation.rs#L433)
3. Added powf to the implementations of `NumericOperationDescription` enum under
3. Added powf to the implementations of `NumericOperationIr` enum under
[crates/burn-fusion/src/stream/context.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-fusion/src/stream/context.rs#L771)

The way `cubecl` handles tensor-scalar operations is by transforming both into a sequence of
Expand Down
1 change: 1 addition & 0 deletions crates/burn-fusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ doc = ["default"]
[dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.17.0" }
burn-common = { path = "../burn-common", version = "0.17.0" }
burn-ir = { path = "../burn-ir", version = "0.17.0" }
hashbrown = { workspace = true }
derive-new = {workspace = true }
spin = { workspace = true }
Expand Down
14 changes: 7 additions & 7 deletions crates/burn-fusion/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{client::FusionClient, stream::Context, FusionClientLocator, FusionTensor};
use burn_ir::{BackendIr, OperationIr, TensorHandle};
use burn_tensor::{
backend::{Backend, DeviceOps},
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
repr::{OperationDescription, ReprBackend, TensorHandle},
Device, Element,
};
use serde::{de::DeserializeOwned, Serialize};
Expand Down Expand Up @@ -77,7 +77,7 @@ pub struct OptimizationProperties {
}

/// The fusion operation abstraction allows implementations to fuse many
/// [tensor operations](OperationDescription) into one, improving the performance of the backend.
/// [tensor operations](OperationIr) into one, improving the performance of the backend.
///
///
/// # Notes
Expand All @@ -89,8 +89,8 @@ pub struct OptimizationProperties {
/// Also, it is important to return (OptimizationStatus::Closed) when no more registered operation can
/// improve the performance.
pub trait OptimizationBuilder<O>: Send {
/// Register a new [tensor operation](OperationDescription).
fn register(&mut self, operation: &OperationDescription);
/// Register a new [tensor operation](OperationIr).
fn register(&mut self, operation: &OperationIr);
/// Finish the optimization and create a fusion operation.
fn build(&self) -> O;
/// Reset the state.
Expand Down Expand Up @@ -154,7 +154,7 @@ pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug {
/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation builder](crate::OptimizationBuilder).
pub trait FusionBackend:
ReprBackend<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
BackendIr<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
{
/// The runtime used for this backend.
type FusionRuntime: FusionRuntime;
Expand All @@ -166,8 +166,8 @@ pub trait FusionBackend:
type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
}

// Fusion implements `ReprBackend` to enable router backend usage.
impl<B: FusionBackend> ReprBackend for Fusion<B> {
// Fusion implements `BackendIr` to enable router backend usage.
impl<B: FusionBackend> BackendIr for Fusion<B> {
type Handle = FusionTensor<B::FusionRuntime>;

fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
Expand Down
26 changes: 12 additions & 14 deletions crates/burn-fusion/src/client/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ use crate::{
stream::{execution::Operation, StreamId},
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor,
};
use burn_tensor::{
repr::{OperationDescription, TensorDescription, TensorId},
DType, TensorData,
};
use burn_ir::{OperationIr, TensorId, TensorIr};
use burn_tensor::{DType, TensorData};

/// Define how to interact with the fusion server.
pub trait FusionClient<R>: Send + Sync + Clone + Sized
Expand All @@ -16,8 +14,8 @@ where
{
/// Create a new client for the given [device](FusionRuntime::FusionDevice).
fn new(device: FusionDevice<R>) -> Self;
/// Register a new [tensor operation description](OperationDescription).
fn register<O>(&self, streams: Vec<StreamId>, description: OperationDescription, operation: O)
/// Register a new [tensor operation intermediate representation](OperationIr).
fn register<O>(&self, streams: Vec<StreamId>, repr: OperationIr, operation: O)
where
O: Operation<R> + 'static;
/// Register all lazy computation.
Expand All @@ -37,31 +35,31 @@ where
/// Read the values contained by a float tensor.
fn read_tensor_float<B>(
self,
tensor: TensorDescription,
tensor: TensorIr,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by an int tensor.
fn read_tensor_int<B>(
self,
tensor: TensorDescription,
tensor: TensorIr,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by a bool tensor.
fn read_tensor_bool<B>(
self,
tensor: TensorDescription,
tensor: TensorIr,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by a quantized tensor.
fn read_tensor_quantized<B>(
self,
tensor: TensorDescription,
tensor: TensorIr,
streams: StreamId,
) -> impl Future<Output = TensorData> + Send + 'static
where
Expand All @@ -81,7 +79,7 @@ where
/// Change the client of the given float tensor.
fn change_client_float<B>(
&self,
tensor: TensorDescription,
tensor: TensorIr,
client: Self,
stream: StreamId,
) -> FusionTensor<R>
Expand All @@ -90,7 +88,7 @@ where
/// Change the client of the given int tensor.
fn change_client_int<B>(
&self,
tensor: TensorDescription,
tensor: TensorIr,
client: Self,
stream: StreamId,
) -> FusionTensor<R>
Expand All @@ -99,7 +97,7 @@ where
/// Change the client of the given bool tensor.
fn change_client_bool<B>(
&self,
tensor: TensorDescription,
tensor: TensorIr,
client: Self,
stream: StreamId,
) -> FusionTensor<R>
Expand All @@ -108,7 +106,7 @@ where
/// Change the client of the given quantized tensor.
fn change_client_quantized<B>(
&self,
tensor: TensorDescription,
tensor: TensorIr,
client: Self,
stream: StreamId,
) -> FusionTensor<R>
Expand Down
Loading

0 comments on commit 5b68f0a

Please sign in to comment.