Refactor burn jit => burn-cubecl (#2809)
nathanielsimard authored Feb 13, 2025
1 parent 3ed38cf commit d9e4146
Showing 109 changed files with 1,082 additions and 1,035 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/publish.yml
@@ -14,7 +14,7 @@ jobs:
- publish-burn-autodiff
- publish-burn-candle
- publish-burn-fusion
-      - publish-burn-jit
+      - publish-burn-cubecl
- publish-burn-ndarray
- publish-burn-tch
- publish-burn-tensor
@@ -113,7 +113,7 @@ jobs:
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

-  publish-burn-jit:
+  publish-burn-cubecl:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs:
- publish-burn-ir
@@ -122,7 +122,7 @@
- publish-burn-tensor
- publish-burn-ndarray
with:
-      crate: burn-jit
+      crate: burn-cubecl
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

@@ -166,7 +166,7 @@ jobs:
- publish-burn-autodiff
- publish-burn-ndarray
- publish-burn-common
-      - publish-burn-jit
+      - publish-burn-cubecl
with:
crate: burn-wgpu
secrets:
@@ -179,7 +179,7 @@ jobs:
- publish-burn-autodiff
- publish-burn-ndarray
- publish-burn-common
-      - publish-burn-jit
+      - publish-burn-cubecl
with:
crate: burn-cuda
secrets:
@@ -192,7 +192,7 @@ jobs:
- publish-burn-autodiff
- publish-burn-ndarray
- publish-burn-common
-      - publish-burn-jit
+      - publish-burn-cubecl
with:
crate: burn-hip
secrets:
10 changes: 5 additions & 5 deletions burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md
@@ -131,8 +131,8 @@ kernel. We'll go into implementing our custom backend trait for the generic JIT
automatically implements the trait for `burn-cuda`, `burn-wgpu` as well as fusion.

```rust, ignore
-/// Implement our custom backend trait for the generic `JitBackend`.
-impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F, I> {
+/// Implement our custom backend trait for the generic `CubeBackend`.
+impl<R: CubeRuntime, F: FloatElement, I: IntElement> Backend for CubeBackend<R, F, I> {
fn fused_matmul_add_relu(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
@@ -172,7 +172,7 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F,
        // Create the output tensor primitive.
-        let output = JitTensor::new_contiguous(
+        let output = CubeTensor::new_contiguous(
lhs.client.clone(),
lhs.device.clone(),
shape_out,
@@ -362,8 +362,8 @@ operation nodes.
The only remaining part is to implement our autodiff-decorated backend trait for our JIT Backend.

```rust, ignore
-impl<R: JitRuntime, F: FloatElement, I: IntElement> AutodiffBackend
-    for Autodiff<JitBackend<R, F, I>>
+impl<R: CubeRuntime, F: FloatElement, I: IntElement> AutodiffBackend
+    for Autodiff<CubeBackend<R, F, I>>
{
}
```
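For projects that implemented this chapter against the old names, the change is a mechanical rename; the kernel itself and the autodiff decoration are untouched. A minimal sketch showing the updated autodiff impl together with the name mapping (the mapping comments only summarize the hunks above, they are not additional API):

```rust, ignore
// Rename summary for this chapter:
//   JitBackend<R, F, I> -> CubeBackend<R, F, I>
//   JitRuntime          -> CubeRuntime
//   JitTensor           -> CubeTensor
impl<R: CubeRuntime, F: FloatElement, I: IntElement> AutodiffBackend
    for Autodiff<CubeBackend<R, F, I>>
{
}
```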
burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md
@@ -200,7 +200,7 @@ the raw `WgpuBackend` type.

```rust, ignore
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
-impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
+impl<F: FloatElement, I: IntElement> Backend for CubeBackend<WgpuRuntime, F, I> {
fn fused_matmul_add_relu(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
@@ -239,7 +239,7 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
.empty(shape_out.num_elements() * core::mem::size_of::<F>());
// Create the output tensor primitive.
-        let output = JitTensor::new_contiguous(
+        let output = CubeTensor::new_contiguous(
lhs.client.clone(),
lhs.device.clone(),
shape_out,
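The net effect for the WGPU-specific chapter is the same rename. A minimal before/after sketch of the impl header, using exactly the identifiers from the hunks above with the bodies elided:

```rust, ignore
// Before this commit:
impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> { /* ... */ }

// After this commit:
impl<F: FloatElement, I: IntElement> Backend for CubeBackend<WgpuRuntime, F, I> { /* ... */ }
```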
44 changes: 22 additions & 22 deletions crates/burn-cubecl/src/backend.rs
@@ -1,4 +1,4 @@
-use crate::{element::BoolElement, tensor::JitTensor, FloatElement, IntElement, JitRuntime};
+use crate::{element::BoolElement, tensor::CubeTensor, CubeRuntime, FloatElement, IntElement};
use burn_tensor::backend::{Backend, DeviceOps};
use cubecl::server::ComputeServer;
use rand::{rngs::StdRng, SeedableRng};
@@ -13,16 +13,16 @@ pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);

/// Generic tensor backend that can be compiled just-in-time to any shader runtime
#[derive(new)]
-pub struct JitBackend<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
+pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
_runtime: PhantomData<R>,
_float_elem: PhantomData<F>,
_int_elem: PhantomData<I>,
_bool_elem: PhantomData<BT>,
}

-impl<R, F, I, BT> Backend for JitBackend<R, F, I, BT>
+impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
where
-    R: JitRuntime,
+    R: CubeRuntime,
R::Server: ComputeServer,
R::Device: burn_tensor::backend::DeviceOps,
F: FloatElement,
@@ -35,14 +35,14 @@ where
type IntElem = I;
type BoolElem = BT;

-    type FloatTensorPrimitive = JitTensor<R>;
-    type IntTensorPrimitive = JitTensor<R>;
-    type BoolTensorPrimitive = JitTensor<R>;
-    type QuantizedTensorPrimitive = JitTensor<R>;
+    type FloatTensorPrimitive = CubeTensor<R>;
+    type IntTensorPrimitive = CubeTensor<R>;
+    type BoolTensorPrimitive = CubeTensor<R>;
+    type QuantizedTensorPrimitive = CubeTensor<R>;
type QuantizedEncoding = u32;

fn name() -> String {
format!("jit<{}>", R::name())
format!("cubecl<{}>", R::name())
}

fn seed(seed: u64) {
@@ -61,43 +61,43 @@
}
}

-impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
-    for JitBackend<R, F, I, BT>
+impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
+    for CubeBackend<R, F, I, BT>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.write_fmt(format_args!("JitBackend {{ runtime: {}}}", R::name()))
+        f.write_fmt(format_args!("CubeBackend {{ runtime: {}}}", R::name()))
}
}

-impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
-    for JitBackend<R, F, I, BT>
+impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
+    for CubeBackend<R, F, I, BT>
{
fn clone(&self) -> Self {
Self::new()
}
}

-impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
-    for JitBackend<R, F, I, BT>
+impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
+    for CubeBackend<R, F, I, BT>
{
fn default() -> Self {
Self::new()
}
}

-impl<R: cubecl::Runtime> JitRuntime for R
+impl<R: cubecl::Runtime> CubeRuntime for R
where
R::Device: DeviceOps,
{
-    type JitDevice = R::Device;
-    type JitServer = R::Server;
+    type CubeDevice = R::Device;
+    type CubeServer = R::Server;
}

#[cfg(not(feature = "fusion"))]
-impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
-    for JitBackend<R, F, I, BT>
+impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
+    for CubeBackend<R, F, I, BT>
{
-    type Handle = JitTensor<R>;
+    type Handle = CubeTensor<R>;

fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
handle.handle
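A quick orientation for code that names the backend directly: the struct keeps its four generic parameters (runtime, float, int, and bool element), only the identifiers change, and `Backend::name()` now reports a `cubecl<...>` prefix. A minimal sketch; the alias `MyBackend`, the `u32` bool element, and the assumption that `CubeBackend` is re-exported from the crate root are illustrative, not part of the commit:

```rust, ignore
use burn_cubecl::CubeBackend;
use burn_tensor::backend::Backend;

// Hypothetical alias over some runtime `R` (any `cubecl::Runtime` whose
// device implements `DeviceOps` gets `CubeRuntime` via the blanket impl
// above). Using `u32` for the bool element here is illustrative.
type MyBackend<R> = CubeBackend<R, f32, i32, u32>;

fn print_backend_name<B: Backend>() {
    // With this commit the reported name is "cubecl<RUNTIME_NAME>"
    // rather than "jit<RUNTIME_NAME>".
    println!("{}", B::name());
}
```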
36 changes: 18 additions & 18 deletions crates/burn-cubecl/src/element.rs
@@ -1,20 +1,20 @@
use cubecl::{
flex32,
prelude::{Float, Int, Numeric},
-    CubeElement,
+    CubeElement as CubeElem,
};

/// The base element trait for the jit backend.
-pub trait JitElement: burn_tensor::Element + CubeElement + PartialEq + Numeric {}
+pub trait CubeElement: burn_tensor::Element + CubeElem + PartialEq + Numeric {}

/// The float element type for the jit backend.
-pub trait FloatElement: JitElement + Float {}
+pub trait FloatElement: CubeElement + Float {}

/// The int element type for the jit backend.
-pub trait IntElement: JitElement + Int {}
+pub trait IntElement: CubeElement + Int {}

/// The element type for booleans for the jit backend.
-pub trait BoolElement: JitElement + Int {
+pub trait BoolElement: CubeElement + Int {
/// The true value for the boolean element.
fn true_val() -> Self {
Self::from_int(1)
@@ -34,19 +34,19 @@ pub trait BoolElement: JitElement + Int {
}
}

-impl JitElement for u64 {}
-impl JitElement for u32 {}
-impl JitElement for u16 {}
-impl JitElement for u8 {}
-impl JitElement for i64 {}
-impl JitElement for i32 {}
-impl JitElement for i16 {}
-impl JitElement for i8 {}
-impl JitElement for f64 {}
-impl JitElement for f32 {}
-impl JitElement for flex32 {}
-impl JitElement for half::f16 {}
-impl JitElement for half::bf16 {}
+impl CubeElement for u64 {}
+impl CubeElement for u32 {}
+impl CubeElement for u16 {}
+impl CubeElement for u8 {}
+impl CubeElement for i64 {}
+impl CubeElement for i32 {}
+impl CubeElement for i16 {}
+impl CubeElement for i8 {}
+impl CubeElement for f64 {}
+impl CubeElement for f32 {}
+impl CubeElement for flex32 {}
+impl CubeElement for half::f16 {}
+impl CubeElement for half::bf16 {}

impl FloatElement for f64 {}
impl FloatElement for f32 {}
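Note the import alias: burn's element trait now shares the `CubeElement` name with the cubecl trait it builds on, so the cubecl trait is brought into this module as `CubeElem`. Downstream generic code simply bounds on the renamed traits; a minimal sketch (the helper functions are illustrative and assume the element traits remain re-exported by `burn-cubecl`):

```rust, ignore
use burn_cubecl::{BoolElement, CubeElement};

// Illustrative helpers bounded by the renamed element traits.
fn elem_size<E: CubeElement>() -> usize {
    core::mem::size_of::<E>()
}

fn bool_true<B: BoolElement>() -> B {
    // `true_val` defaults to `Self::from_int(1)` in the trait above.
    B::true_val()
}
```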
