From d9e41460ffb5cabcac7bbaa11246093545d26cf8 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Thu, 13 Feb 2025 12:39:29 -0500 Subject: [PATCH] Refactor burn jit => burn-cubecl (#2809) --- .github/workflows/publish.yml | 12 +- .../backend-extension/custom-cubecl-kernel.md | 10 +- .../backend-extension/custom-wgpu-kernel.md | 4 +- crates/burn-cubecl/src/backend.rs | 44 +++---- crates/burn-cubecl/src/element.rs | 36 +++--- crates/burn-cubecl/src/fusion/base.rs | 80 ++++++------ .../src/fusion/elemwise/builder.rs | 14 +-- .../src/fusion/elemwise/optimization.rs | 10 +- .../burn-cubecl/src/fusion/matmul/builder.rs | 14 +-- .../src/fusion/matmul/optimization.rs | 27 ++-- crates/burn-cubecl/src/fusion/matmul/tune.rs | 24 ++-- .../src/fusion/on_write/trace/base.rs | 10 +- .../src/fusion/on_write/trace/executor.rs | 14 +-- .../src/fusion/on_write/trace/input.rs | 10 +- .../src/fusion/on_write/trace/output.rs | 24 ++-- .../src/fusion/on_write/trace/plan.rs | 16 +-- .../src/fusion/on_write/trace/runner.rs | 14 +-- .../fusion/on_write/trace/vectorization.rs | 10 +- crates/burn-cubecl/src/fusion/tune.rs | 36 +++--- crates/burn-cubecl/src/kernel/binary.rs | 16 +-- crates/burn-cubecl/src/kernel/binary_int.rs | 16 +-- crates/burn-cubecl/src/kernel/cast/base.rs | 10 +- .../burn-cubecl/src/kernel/cast/bool_cast.rs | 10 +- crates/burn-cubecl/src/kernel/clamp.rs | 12 +- crates/burn-cubecl/src/kernel/comparison.rs | 94 +++++++------- crates/burn-cubecl/src/kernel/contiguous.rs | 6 +- .../src/kernel/conv/conv2d/base.rs | 22 ++-- .../src/kernel/conv/conv2d/col2im.rs | 37 +++--- .../src/kernel/conv/conv2d/direct.rs | 14 +-- .../src/kernel/conv/conv2d/gemm/launch.rs | 48 +++---- .../src/kernel/conv/conv2d/gemm/selection.rs | 9 +- .../src/kernel/conv/conv2d/im2col.rs | 38 +++--- .../src/kernel/conv/conv2d/implicit_gemm.rs | 20 +-- .../src/kernel/conv/conv2d/layout_swap.rs | 6 +- .../kernel/conv/conv2d/transpose_direct.rs | 16 +-- .../src/kernel/conv/conv2d/tune/conv2d.rs | 48 +++---- .../conv/conv2d/tune/conv_transpose2d.rs | 48 +++---- crates/burn-cubecl/src/kernel/conv/conv3d.rs | 14 +-- .../src/kernel/conv/conv_transpose3d.rs | 16 +-- .../src/kernel/conv/deform_conv2d.rs | 28 ++--- .../kernel/conv/deform_conv_transpose2d.rs | 74 +++++------ crates/burn-cubecl/src/kernel/index/flip.rs | 16 +-- crates/burn-cubecl/src/kernel/index/gather.rs | 10 +- .../src/kernel/index/repeat_dim.rs | 8 +- .../burn-cubecl/src/kernel/index/scatter.rs | 16 +-- crates/burn-cubecl/src/kernel/index/select.rs | 10 +- .../src/kernel/index/select_assign.rs | 12 +- crates/burn-cubecl/src/kernel/index/slice.rs | 18 +-- .../src/kernel/index/slice_assign.rs | 10 +- .../src/kernel/interpolate/base.rs | 20 +-- .../src/kernel/interpolate/bicubic.rs | 10 +- .../src/kernel/interpolate/bilinear.rs | 10 +- .../src/kernel/interpolate/nearest.rs | 10 +- .../kernel/interpolate/nearest_backward.rs | 10 +- crates/burn-cubecl/src/kernel/mask/base.rs | 20 +-- .../burn-cubecl/src/kernel/mask/mask_fill.rs | 30 ++--- .../burn-cubecl/src/kernel/mask/mask_where.rs | 36 +++--- crates/burn-cubecl/src/kernel/matmul/base.rs | 12 +- .../src/kernel/matmul/tune/base.rs | 56 ++++----- .../burn-cubecl/src/kernel/matmul/tune/key.rs | 14 +-- crates/burn-cubecl/src/kernel/matmul/utils.rs | 12 +- .../src/kernel/pool/adaptive_avg_pool2d.rs | 8 +- .../pool/adaptive_avg_pool2d_backward.rs | 12 +- .../burn-cubecl/src/kernel/pool/avg_pool2d.rs | 8 +- .../src/kernel/pool/avg_pool2d_backward.rs | 12 +- .../burn-cubecl/src/kernel/pool/max_pool2d.rs | 14 +-- 
.../src/kernel/pool/max_pool2d_backward.rs | 14 +-- crates/burn-cubecl/src/kernel/prng/base.rs | 12 +- .../burn-cubecl/src/kernel/prng/bernoulli.rs | 12 +- crates/burn-cubecl/src/kernel/prng/normal.rs | 12 +- crates/burn-cubecl/src/kernel/prng/uniform.rs | 18 +-- .../src/kernel/quantization/dequantize.rs | 16 +-- .../src/kernel/quantization/quantize.rs | 30 ++--- crates/burn-cubecl/src/kernel/reduce/base.rs | 33 ++--- crates/burn-cubecl/src/kernel/reduce/tune.rs | 118 +++++++++--------- crates/burn-cubecl/src/kernel/unary_float.rs | 12 +- crates/burn-cubecl/src/kernel/unary_int.rs | 10 +- .../burn-cubecl/src/kernel/unary_numeric.rs | 11 +- crates/burn-cubecl/src/lib.rs | 20 +-- crates/burn-cubecl/src/ops/activation_ops.rs | 6 +- crates/burn-cubecl/src/ops/base.rs | 45 +++---- crates/burn-cubecl/src/ops/bool_ops.rs | 6 +- crates/burn-cubecl/src/ops/float_ops.rs | 8 +- crates/burn-cubecl/src/ops/int_ops.rs | 6 +- crates/burn-cubecl/src/ops/module_ops.rs | 6 +- crates/burn-cubecl/src/ops/numeric.rs | 115 ++++++++++------- crates/burn-cubecl/src/ops/qtensor.rs | 14 +-- crates/burn-cubecl/src/ops/transaction.rs | 6 +- crates/burn-cubecl/src/template/base.rs | 4 +- crates/burn-cubecl/src/tensor/base.rs | 34 ++--- crates/burn-cubecl/src/tests/mod.rs | 12 +- crates/burn-cubecl/src/tune_key.rs | 16 +-- crates/burn-cuda/src/lib.rs | 8 +- crates/burn-hip/src/lib.rs | 8 +- crates/burn-vision/Cargo.toml | 10 +- .../hardware_accelerated.rs | 16 +-- .../{jit => cube}/connected_components/mod.rs | 12 +- .../connected_components/prefix_sum.rs | 6 +- .../src/backends/{jit => cube}/mod.rs | 0 .../src/backends/{jit => cube}/ops.rs | 6 +- crates/burn-vision/src/backends/mod.rs | 4 +- crates/burn-wgpu/src/lib.rs | 8 +- .../examples/custom-cubecl-kernel.rs | 2 +- examples/custom-cubecl-kernel/src/backward.rs | 6 +- examples/custom-cubecl-kernel/src/forward.rs | 12 +- .../examples/custom-wgpu-kernel.rs | 2 +- examples/custom-wgpu-kernel/src/backward.rs | 4 +- examples/custom-wgpu-kernel/src/forward.rs | 8 +- examples/modern-lstm/README.md | 4 +- 109 files changed, 1082 insertions(+), 1035 deletions(-) rename crates/burn-vision/src/backends/{jit => cube}/connected_components/hardware_accelerated.rs (97%) rename crates/burn-vision/src/backends/{jit => cube}/connected_components/mod.rs (87%) rename crates/burn-vision/src/backends/{jit => cube}/connected_components/prefix_sum.rs (98%) rename crates/burn-vision/src/backends/{jit => cube}/mod.rs (100%) rename crates/burn-vision/src/backends/{jit => cube}/ops.rs (97%) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index fc9a76594e..c37fd9684d 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -14,7 +14,7 @@ jobs: - publish-burn-autodiff - publish-burn-candle - publish-burn-fusion - - publish-burn-jit + - publish-burn-cubecl - publish-burn-ndarray - publish-burn-tch - publish-burn-tensor @@ -113,7 +113,7 @@ jobs: secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} - publish-burn-jit: + publish-burn-cubecl: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 needs: - publish-burn-ir @@ -122,7 +122,7 @@ jobs: - publish-burn-tensor - publish-burn-ndarray with: - crate: burn-jit + crate: burn-cubecl secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} @@ -166,7 +166,7 @@ jobs: - publish-burn-autodiff - publish-burn-ndarray - publish-burn-common - - publish-burn-jit + - publish-burn-cubecl with: crate: burn-wgpu secrets: @@ -179,7 +179,7 @@ jobs: - publish-burn-autodiff - 
publish-burn-ndarray - publish-burn-common - - publish-burn-jit + - publish-burn-cubecl with: crate: burn-cuda secrets: @@ -192,7 +192,7 @@ jobs: - publish-burn-autodiff - publish-burn-ndarray - publish-burn-common - - publish-burn-jit + - publish-burn-cubecl with: crate: burn-hip secrets: diff --git a/burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md b/burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md index 4dad8cb7e9..2ef3a6f19e 100644 --- a/burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md +++ b/burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md @@ -131,8 +131,8 @@ kernel. We'll go into implementing our custom backend trait for the generic JIT automatically implements the trait for `burn-cuda`, `burn-wgpu` as well as fusion. ```rust, ignore -/// Implement our custom backend trait for the generic `JitBackend`. -impl Backend for JitBackend { +/// Implement our custom backend trait for the generic `CubeBackend`. +impl Backend for CubeBackend { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, @@ -172,7 +172,7 @@ impl Backend for JitBackend AutodiffBackend - for Autodiff> +impl AutodiffBackend + for Autodiff> { } ``` diff --git a/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md b/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md index 69ca45ed97..f6b3db1291 100644 --- a/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md +++ b/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md @@ -200,7 +200,7 @@ the raw `WgpuBackend` type. ```rust, ignore /// Implement our custom backend trait for the existing backend `WgpuBackend`. -impl Backend for JitBackend { +impl Backend for CubeBackend { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, @@ -239,7 +239,7 @@ impl Backend for JitBackend { .empty(shape_out.num_elements() * core::mem::size_of::()); // Create the output tensor primitive. 
- let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( lhs.client.clone(), lhs.device.clone(), shape_out, diff --git a/crates/burn-cubecl/src/backend.rs b/crates/burn-cubecl/src/backend.rs index 36bc3e5eef..f9eb685990 100644 --- a/crates/burn-cubecl/src/backend.rs +++ b/crates/burn-cubecl/src/backend.rs @@ -1,4 +1,4 @@ -use crate::{element::BoolElement, tensor::JitTensor, FloatElement, IntElement, JitRuntime}; +use crate::{element::BoolElement, tensor::CubeTensor, CubeRuntime, FloatElement, IntElement}; use burn_tensor::backend::{Backend, DeviceOps}; use cubecl::server::ComputeServer; use rand::{rngs::StdRng, SeedableRng}; @@ -13,16 +13,16 @@ pub(crate) static SEED: Mutex> = Mutex::new(None); /// Generic tensor backend that can be compiled just-in-time to any shader runtime #[derive(new)] -pub struct JitBackend { +pub struct CubeBackend { _runtime: PhantomData, _float_elem: PhantomData, _int_elem: PhantomData, _bool_elem: PhantomData, } -impl Backend for JitBackend +impl Backend for CubeBackend where - R: JitRuntime, + R: CubeRuntime, R::Server: ComputeServer, R::Device: burn_tensor::backend::DeviceOps, F: FloatElement, @@ -35,14 +35,14 @@ where type IntElem = I; type BoolElem = BT; - type FloatTensorPrimitive = JitTensor; - type IntTensorPrimitive = JitTensor; - type BoolTensorPrimitive = JitTensor; - type QuantizedTensorPrimitive = JitTensor; + type FloatTensorPrimitive = CubeTensor; + type IntTensorPrimitive = CubeTensor; + type BoolTensorPrimitive = CubeTensor; + type QuantizedTensorPrimitive = CubeTensor; type QuantizedEncoding = u32; fn name() -> String { - format!("jit<{}>", R::name()) + format!("cubecl<{}>", R::name()) } fn seed(seed: u64) { @@ -61,43 +61,43 @@ where } } -impl core::fmt::Debug - for JitBackend +impl core::fmt::Debug + for CubeBackend { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!("JitBackend {{ runtime: {}}}", R::name())) + f.write_fmt(format_args!("CubeBackend {{ runtime: {}}}", R::name())) } } -impl Clone - for JitBackend +impl Clone + for CubeBackend { fn clone(&self) -> Self { Self::new() } } -impl Default - for JitBackend +impl Default + for CubeBackend { fn default() -> Self { Self::new() } } -impl JitRuntime for R +impl CubeRuntime for R where R::Device: DeviceOps, { - type JitDevice = R::Device; - type JitServer = R::Server; + type CubeDevice = R::Device; + type CubeServer = R::Server; } #[cfg(not(feature = "fusion"))] -impl BackendIr - for JitBackend +impl BackendIr + for CubeBackend { - type Handle = JitTensor; + type Handle = CubeTensor; fn float_tensor(handle: TensorHandle) -> FloatTensor { handle.handle diff --git a/crates/burn-cubecl/src/element.rs b/crates/burn-cubecl/src/element.rs index a1bbab7f5f..92398e4972 100644 --- a/crates/burn-cubecl/src/element.rs +++ b/crates/burn-cubecl/src/element.rs @@ -1,20 +1,20 @@ use cubecl::{ flex32, prelude::{Float, Int, Numeric}, - CubeElement, + CubeElement as CubeElem, }; /// The base element trait for the jit backend. -pub trait JitElement: burn_tensor::Element + CubeElement + PartialEq + Numeric {} +pub trait CubeElement: burn_tensor::Element + CubeElem + PartialEq + Numeric {} /// The float element type for the jit backend. -pub trait FloatElement: JitElement + Float {} +pub trait FloatElement: CubeElement + Float {} /// The int element type for the jit backend. -pub trait IntElement: JitElement + Int {} +pub trait IntElement: CubeElement + Int {} /// The element type for booleans for the jit backend. 
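Since this hunk renames the public backend type, a minimal migration sketch for dependent code is shown below. The `WgpuRuntime` import path and the `f32`/`i32`/`u32` element choices are assumptions for illustration and are not part of this patch.

```rust, ignore
// Hypothetical migration shim: keep the old name compiling as a deprecated
// alias while call sites move to `CubeBackend`.
use burn_cubecl::CubeBackend;
// Assumed path for a concrete runtime; adjust to whichever runtime you use.
use cubecl::wgpu::WgpuRuntime;

#[deprecated(note = "renamed to `CubeBackend` in burn-cubecl")]
pub type JitBackend<R, F, I, BT> = CubeBackend<R, F, I, BT>;

// Element choices (f32 floats, i32 ints, u32 bools) are illustrative.
pub type MyBackend = CubeBackend<WgpuRuntime, f32, i32, u32>;
```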
-pub trait BoolElement: JitElement + Int { +pub trait BoolElement: CubeElement + Int { /// The true value for the boolean element. fn true_val() -> Self { Self::from_int(1) @@ -34,19 +34,19 @@ pub trait BoolElement: JitElement + Int { } } -impl JitElement for u64 {} -impl JitElement for u32 {} -impl JitElement for u16 {} -impl JitElement for u8 {} -impl JitElement for i64 {} -impl JitElement for i32 {} -impl JitElement for i16 {} -impl JitElement for i8 {} -impl JitElement for f64 {} -impl JitElement for f32 {} -impl JitElement for flex32 {} -impl JitElement for half::f16 {} -impl JitElement for half::bf16 {} +impl CubeElement for u64 {} +impl CubeElement for u32 {} +impl CubeElement for u16 {} +impl CubeElement for u8 {} +impl CubeElement for i64 {} +impl CubeElement for i32 {} +impl CubeElement for i16 {} +impl CubeElement for i8 {} +impl CubeElement for f64 {} +impl CubeElement for f32 {} +impl CubeElement for flex32 {} +impl CubeElement for half::f16 {} +impl CubeElement for half::bf16 {} impl FloatElement for f64 {} impl FloatElement for f32 {} diff --git a/crates/burn-cubecl/src/fusion/base.rs b/crates/burn-cubecl/src/fusion/base.rs index 625ac7ded4..3d7a0b5221 100644 --- a/crates/burn-cubecl/src/fusion/base.rs +++ b/crates/burn-cubecl/src/fusion/base.rs @@ -3,7 +3,7 @@ use super::matmul::optimization::{MatmulOptimization, MatmulOptimizationState}; use crate::fusion::elemwise::builder::ElementWiseBuilder; use crate::fusion::matmul::builder::MatmulBuilder; use crate::BoolElement; -use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{kernel, tensor::CubeTensor, CubeBackend, CubeRuntime, FloatElement, IntElement}; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_ir::{BackendIr, TensorHandle}; @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; /// Fusion optimization type for JIT. /// /// More optimization variants should be added here. -pub enum JitOptimization { +pub enum CubeOptimization { /// Element wise optimization. ElementWise(ElemwiseOptimization), /// Matrix multiplication optimization. @@ -28,19 +28,19 @@ pub enum JitOptimization { /// /// More optimization variants should be added here. #[derive(Serialize, Deserialize)] -pub enum JitOptimizationState { +pub enum CubeOptimizationState { /// Element wise state. ElementWise(ElemwiseOptimizationState), /// Matrix multiplication optimization state. 
Matmul(MatmulOptimizationState), } -impl burn_fusion::Optimization> for JitOptimization +impl burn_fusion::Optimization> for CubeOptimization where - R: JitRuntime, + R: CubeRuntime, BT: BoolElement, { - fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle>) { + fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle>) { match self { Self::ElementWise(op) => op.execute::(context), Self::Matmul(op) => op.execute::(context), @@ -54,29 +54,29 @@ where } } - fn to_state(&self) -> JitOptimizationState { + fn to_state(&self) -> CubeOptimizationState { match self { - Self::ElementWise(value) => JitOptimizationState::ElementWise(value.to_state()), - Self::Matmul(value) => JitOptimizationState::Matmul(value.to_state()), + Self::ElementWise(value) => CubeOptimizationState::ElementWise(value.to_state()), + Self::Matmul(value) => CubeOptimizationState::Matmul(value.to_state()), } } - fn from_state(device: &R::Device, state: JitOptimizationState) -> Self { + fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self { match state { - JitOptimizationState::ElementWise(state) => { + CubeOptimizationState::ElementWise(state) => { Self::ElementWise(ElemwiseOptimization::from_state(device, state)) } - JitOptimizationState::Matmul(state) => { + CubeOptimizationState::Matmul(state) => { Self::Matmul(MatmulOptimization::from_state(device, state)) } } } } -impl BackendIr - for JitBackend +impl BackendIr + for CubeBackend { - type Handle = JitFusionHandle; + type Handle = CubeFusionHandle; fn float_tensor(handle: TensorHandle) -> burn_tensor::ops::FloatTensor { handle.handle.into_tensor(handle.shape) @@ -113,11 +113,11 @@ impl BackendIr } } -impl FusionRuntime for FusionJitRuntime { - type OptimizationState = JitOptimizationState; - type Optimization = JitOptimization; - type FusionHandle = JitFusionHandle; - type FusionDevice = R::JitDevice; +impl FusionRuntime for FusionCubeRuntime { + type OptimizationState = CubeOptimizationState; + type Optimization = CubeOptimization; + type FusionHandle = CubeFusionHandle; + type FusionDevice = R::CubeDevice; type FusionClient = MutexFusionClient; type BoolRepr = BT; @@ -139,26 +139,26 @@ impl FusionRuntime for FusionJitRuntime { /// Fusion runtime for JIT runtimes. #[derive(Debug)] -pub struct FusionJitRuntime { +pub struct FusionCubeRuntime { _b: PhantomData, _bool: PhantomData, } -impl FusionBackend - for JitBackend +impl FusionBackend + for CubeBackend { - type FusionRuntime = FusionJitRuntime; + type FusionRuntime = FusionCubeRuntime; - type FullPrecisionBackend = JitBackend; + type FullPrecisionBackend = CubeBackend; fn cast_float( tensor: burn_tensor::ops::FloatTensor, dtype: burn_tensor::DType, ) -> Self::Handle { - fn cast( - tensor: JitTensor, - ) -> JitFusionHandle { - JitFusionHandle::from(kernel::cast::(tensor)) + fn cast( + tensor: CubeTensor, + ) -> CubeFusionHandle { + CubeFusionHandle::from(kernel::cast::(tensor)) } match dtype { @@ -183,7 +183,7 @@ pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec { } /// Handle to be used when fusing operations. -pub struct JitFusionHandle { +pub struct CubeFusionHandle { /// Compute client for jit. pub client: ComputeClient, /// The buffer where the data are stored. 
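Only the signature of `strides_dyn_rank` appears as hunk context above, so a self-contained sketch of the usual row-major stride computation it performs may help; the body below is a reconstruction for illustration, not the crate's exact code.

```rust
// Row-major ("contiguous") strides: each stride is the product of all
// trailing dimension sizes, e.g. shape [2, 3, 4] -> strides [12, 4, 1].
fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    let mut current = 1usize;
    for (stride, dim) in strides.iter_mut().zip(shape.iter()).rev() {
        *stride = current;
        current *= dim;
    }
    strides
}

fn main() {
    assert_eq!(strides_dyn_rank(&[2, 3, 4]), vec![12, 4, 1]);
}
```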
@@ -194,17 +194,17 @@ pub struct JitFusionHandle { pub(crate) strides: Vec, } -impl core::fmt::Debug for JitFusionHandle { +impl core::fmt::Debug for CubeFusionHandle { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( - "JitFusionHandle {{ device: {:?}, runtime: {}}}", + "CubeFusionHandle {{ device: {:?}, runtime: {}}}", self.device, R::name(), )) } } -impl Clone for JitFusionHandle { +impl Clone for CubeFusionHandle { fn clone(&self) -> Self { Self { client: self.client.clone(), @@ -216,12 +216,12 @@ impl Clone for JitFusionHandle { } } -unsafe impl Send for JitFusionHandle {} -unsafe impl Sync for JitFusionHandle {} +unsafe impl Send for CubeFusionHandle {} +unsafe impl Sync for CubeFusionHandle {} -impl JitFusionHandle { - pub(crate) fn into_tensor(self, shape: Shape) -> JitTensor { - JitTensor { +impl CubeFusionHandle { + pub(crate) fn into_tensor(self, shape: Shape) -> CubeTensor { + CubeTensor { client: self.client, handle: self.handle, device: self.device, @@ -256,8 +256,8 @@ impl JitFusionHandle { } } -impl From> for JitFusionHandle { - fn from(value: JitTensor) -> Self { +impl From> for CubeFusionHandle { + fn from(value: CubeTensor) -> Self { Self { client: value.client, handle: value.handle, diff --git a/crates/burn-cubecl/src/fusion/elemwise/builder.rs b/crates/burn-cubecl/src/fusion/elemwise/builder.rs index bda055231a..9fbff34bc4 100644 --- a/crates/burn-cubecl/src/fusion/elemwise/builder.rs +++ b/crates/burn-cubecl/src/fusion/elemwise/builder.rs @@ -3,20 +3,20 @@ use burn_fusion::OptimizationBuilder; use crate::{ fusion::{ on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, - JitOptimization, + CubeOptimization, }, - JitRuntime, + CubeRuntime, }; use super::optimization::ElemwiseOptimization; /// Fused element wise operations that are normally memory bound. 
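The `From<CubeTensor<R>>` impl and `into_tensor` above form a conversion pair; the sketch below shows how they compose, written with `crate::` paths because `into_tensor` is `pub(crate)`. The helper name is illustrative.

```rust, ignore
// Round trip between the fusion handle and the tensor type, as seen from
// inside burn-cubecl. The handle keeps client/buffer/device/strides but not
// the shape, so the shape must be supplied again when rebuilding a tensor.
use crate::{fusion::CubeFusionHandle, tensor::CubeTensor, CubeRuntime};
use burn_tensor::Shape;

fn roundtrip<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    let shape: Shape = tensor.shape.clone();
    let handle: CubeFusionHandle<R> = tensor.into();
    handle.into_tensor(shape)
}
```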
-pub(crate) struct ElementWiseBuilder { +pub(crate) struct ElementWiseBuilder { builder: FuseOnWriteBuilder, device: R::Device, } -impl ElementWiseBuilder { +impl ElementWiseBuilder { pub fn new(device: R::Device, bool_precision: ElemwisePrecision) -> Self { let client = R::client(&device); let props = client.properties(); @@ -38,18 +38,18 @@ impl ElementWiseBuilder { } } -impl OptimizationBuilder> for ElementWiseBuilder { +impl OptimizationBuilder> for ElementWiseBuilder { fn register(&mut self, operation: &burn_ir::OperationIr) { self.builder.register(operation); } - fn build(&self) -> JitOptimization { + fn build(&self) -> CubeOptimization { let client = R::client(&self.device); let trace = self.builder.build(); let elementwise = ElemwiseOptimization::::new(trace, client, self.device.clone(), self.len()); - JitOptimization::ElementWise(elementwise) + CubeOptimization::ElementWise(elementwise) } fn reset(&mut self) { diff --git a/crates/burn-cubecl/src/fusion/elemwise/optimization.rs b/crates/burn-cubecl/src/fusion/elemwise/optimization.rs index 71b44b8e44..6578ad4ff2 100644 --- a/crates/burn-cubecl/src/fusion/elemwise/optimization.rs +++ b/crates/burn-cubecl/src/fusion/elemwise/optimization.rs @@ -1,5 +1,5 @@ use crate::{fusion::on_write::kernel::fuse_on_write, BoolElement}; -use crate::{fusion::JitFusionHandle, JitRuntime}; +use crate::{fusion::CubeFusionHandle, CubeRuntime}; use burn_fusion::stream::Context; use cubecl::{calculate_cube_count_elemwise, client::ComputeClient, prelude::*, CubeDim}; use serde::{Deserialize, Serialize}; @@ -11,7 +11,7 @@ use crate::fusion::on_write::{ #[derive(new)] /// Fuse element wise operations into a single kernel. -pub struct ElemwiseOptimization { +pub struct ElemwiseOptimization { trace: FuseOnWriteTrace, client: ComputeClient, device: R::Device, @@ -25,9 +25,9 @@ pub struct ElemwiseOptimizationState { len: usize, } -impl ElemwiseOptimization { +impl ElemwiseOptimization { /// Execute the optimization. - pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { + pub fn execute(&mut self, context: &mut Context<'_, CubeFusionHandle>) { self.trace .run::(&self.client, &self.device, context, &ElemwiseRunner) .unwrap(); @@ -59,7 +59,7 @@ impl ElemwiseOptimization { pub struct ElemwiseRunner; -impl TraceRunner for ElemwiseRunner { +impl TraceRunner for ElemwiseRunner { type Error = (); // No error possible fn run<'a>( diff --git a/crates/burn-cubecl/src/fusion/matmul/builder.rs b/crates/burn-cubecl/src/fusion/matmul/builder.rs index 3979a91b4f..59b7cecafe 100644 --- a/crates/burn-cubecl/src/fusion/matmul/builder.rs +++ b/crates/burn-cubecl/src/fusion/matmul/builder.rs @@ -4,22 +4,22 @@ use burn_ir::{FloatOperationIr, OperationIr}; use crate::{ fusion::{ on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, - JitOptimization, + CubeOptimization, }, - JitRuntime, + CubeRuntime, }; use super::optimization::{FusedMatmul, MatmulOptimization}; /// Fused element wise operations that are normally memory bound. 
-pub(crate) struct MatmulBuilder { +pub(crate) struct MatmulBuilder { builder: FuseOnWriteBuilder, builder_fallback: FuseOnWriteBuilder, device: R::Device, matmul: Option, } -impl MatmulBuilder { +impl MatmulBuilder { pub fn new(device: R::Device, bool_precision: ElemwisePrecision) -> Self { let client = R::client(&device); let props = client.properties(); @@ -40,7 +40,7 @@ impl MatmulBuilder { } } -impl OptimizationBuilder> for MatmulBuilder { +impl OptimizationBuilder> for MatmulBuilder { fn register(&mut self, operation: &OperationIr) { if let OptimizationStatus::Closed = self.builder.status() { return; @@ -74,7 +74,7 @@ impl OptimizationBuilder> for MatmulBuilder } } - fn build(&self) -> JitOptimization { + fn build(&self) -> CubeOptimization { let client = R::client(&self.device); let trace = self.builder.build(); let trace_fallback = self.builder_fallback.build(); @@ -88,7 +88,7 @@ impl OptimizationBuilder> for MatmulBuilder self.matmul.as_ref().unwrap().clone(), ); - JitOptimization::Matmul(matmul) + CubeOptimization::Matmul(matmul) } fn reset(&mut self) { diff --git a/crates/burn-cubecl/src/fusion/matmul/optimization.rs b/crates/burn-cubecl/src/fusion/matmul/optimization.rs index 9b22caac38..28581bf07c 100644 --- a/crates/burn-cubecl/src/fusion/matmul/optimization.rs +++ b/crates/burn-cubecl/src/fusion/matmul/optimization.rs @@ -3,7 +3,7 @@ use std::any::TypeId; use crate::fusion::elemwise::optimization::ElemwiseRunner; use crate::fusion::on_write::ir::ElemwisePrecision; use crate::kernel::matmul; -use crate::{fusion::JitFusionHandle, JitRuntime}; +use crate::{fusion::CubeFusionHandle, CubeRuntime}; use crate::{BoolElement, FloatElement}; use burn_fusion::stream::Context; @@ -31,7 +31,7 @@ use super::spec::FusedMatmulSpec; use super::tune::fused_matmul_autotune; /// Fuse matmul operation followed by elemwise operations into a single kernel. -pub struct MatmulOptimization { +pub struct MatmulOptimization { trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, pub(crate) client: ComputeClient, @@ -53,7 +53,7 @@ pub struct MatmulOptimizationState { len: usize, } -impl MatmulOptimization { +impl MatmulOptimization { pub fn new( trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, @@ -82,7 +82,7 @@ impl MatmulOptimization { } } /// Execute the optimization. 
- pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { + pub fn execute(&mut self, context: &mut Context<'_, CubeFusionHandle>) { #[cfg(feature = "autotune")] fused_matmul_autotune::(self, context); @@ -130,7 +130,7 @@ impl MatmulOptimization { pub fn execute_standard_fused( &self, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, ) -> Result<(), FusedMatmulError> { self.trace.run::( &self.client, @@ -142,7 +142,7 @@ impl MatmulOptimization { pub fn execute_specialized_fused( &self, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, ) -> Result<(), FusedMatmulError> { self.trace.run::( &self.client, @@ -154,7 +154,7 @@ impl MatmulOptimization { pub fn execute_pipelined_fused( &self, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, ) -> Result<(), FusedMatmulError> { self.trace.run::( &self.client, @@ -164,7 +164,10 @@ impl MatmulOptimization { ) } - pub fn execute_fallback(&self, context: &mut Context<'_, JitFusionHandle>) { + pub fn execute_fallback( + &self, + context: &mut Context<'_, CubeFusionHandle>, + ) { match self.matmul_standard.lhs.precision() { ElemwisePrecision::F32 => self.run_fallback::(context), ElemwisePrecision::F16 => self.run_fallback::(context), @@ -175,7 +178,7 @@ impl MatmulOptimization { fn run_fallback( &self, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, ) { let (out_tensor, out_desc) = { let lhs = context @@ -214,7 +217,7 @@ impl MatmulOptimization { }; context .handles - .register_handle(out_desc.id, JitFusionHandle::from(out_tensor)); + .register_handle(out_desc.id, CubeFusionHandle::from(out_tensor)); self.trace_fallback .run::(&self.client, &self.device, context, &ElemwiseRunner) @@ -251,7 +254,7 @@ impl From for FusedMatmulError { } } -impl TraceRunner for FusedMatmul { +impl TraceRunner for FusedMatmul { type Error = FusedMatmulError; fn run<'a>( @@ -273,7 +276,7 @@ impl TraceRunner for FusedMatmul { } impl FusedMatmul { - fn matmul_fused<'a, R: JitRuntime, EG: Numeric>( + fn matmul_fused<'a, R: CubeRuntime, EG: Numeric>( &'a self, client: &'a ComputeClient, inputs: GlobalArgsLaunch<'a, R>, diff --git a/crates/burn-cubecl/src/fusion/matmul/tune.rs b/crates/burn-cubecl/src/fusion/matmul/tune.rs index d98165fa2e..9d7eb11d8f 100644 --- a/crates/burn-cubecl/src/fusion/matmul/tune.rs +++ b/crates/burn-cubecl/src/fusion/matmul/tune.rs @@ -1,10 +1,10 @@ use crate::{ fusion::{ tune::{TuneContext, TuneInput}, - JitFusionHandle, + CubeFusionHandle, }, kernel::matmul::MatmulAutotuneKey, - BoolElement, JitRuntime, JitTuneId, + BoolElement, CubeRuntime, CubeTuneId, }; use burn_fusion::stream::Context; use cubecl::{ @@ -25,11 +25,11 @@ pub struct FusedMatmulAutotuneKey { } /// Executes autotune on matmul operations -pub fn fused_matmul_autotune( +pub fn fused_matmul_autotune( optimization: &MatmulOptimization, - context: &mut Context>, + context: &mut Context>, ) { - static TUNER: LocalTuner = local_tuner!(); + static TUNER: LocalTuner = local_tuner!(); let tunables = TunableSet::new(create_key::, input_gen::) .with_tunable(tune_standard_fused::) @@ -38,14 +38,14 @@ pub fn fused_matmul_autotune( .with_tunable(tune_fallback::); TUNER.execute( - &JitTuneId::new::(&optimization.device), + &CubeTuneId::new::(&optimization.device), &optimization.client, &tunables, TuneInput::new(context, optimization), ); } -pub(crate) fn create_key( +pub(crate) fn create_key( input: &TuneInput>, ) -> 
FusedMatmulAutotuneKey { let opt = input.optimization(); @@ -66,14 +66,14 @@ pub(crate) fn create_key( FusedMatmulAutotuneKey::new(key, opt.num_output_buffers(), opt.num_ops_fused()) } -fn input_gen( +fn input_gen( _key: &FusedMatmulAutotuneKey, input: &TuneInput>, ) -> TuneInput> { input.clone() } -fn tune_standard_fused( +fn tune_standard_fused( input: TuneInput>, ) -> Result<(), String> { let optimization = input.optimization(); @@ -88,7 +88,7 @@ fn tune_standard_fused( .map_err(|e| format!("{e:?}")) } -fn tune_specialized_fused( +fn tune_specialized_fused( input: TuneInput>, ) -> Result<(), String> { let optimization = input.optimization(); @@ -103,7 +103,7 @@ fn tune_specialized_fused( .map_err(|e| format!("{e:?}")) } -fn tune_pipelined_fused( +fn tune_pipelined_fused( input: TuneInput>, ) -> Result<(), String> { let optimization = input.optimization(); @@ -118,7 +118,7 @@ fn tune_pipelined_fused( .map_err(|e| format!("{e:?}")) } -fn tune_fallback( +fn tune_fallback( input: TuneInput>, ) -> Result<(), String> { let optimization = input.optimization(); diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/base.rs b/crates/burn-cubecl/src/fusion/on_write/trace/base.rs index a812467e8e..80ce63db23 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/base.rs +++ b/crates/burn-cubecl/src/fusion/on_write/trace/base.rs @@ -1,4 +1,4 @@ -use crate::{fusion::JitFusionHandle, BoolElement, JitRuntime}; +use crate::{fusion::CubeFusionHandle, BoolElement, CubeRuntime}; use super::{ super::{ @@ -48,11 +48,11 @@ pub enum TensorView { impl FuseOnWriteTrace { /// Run a trace with the given [runner](TraceRunner). - pub fn run>( + pub fn run>( &self, client: &ComputeClient, device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, runner: &Runner, ) -> Result<(), Runner::Error> { let mut plan = LaunchPlan::new(&self.reads, &self.writes, self.shape_ref.len()); @@ -83,9 +83,9 @@ impl FuseOnWriteTrace { } } - fn rollback( + fn rollback( &self, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, handle_inputs: Vec>, handle_outputs: Vec>, ) { diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/executor.rs b/crates/burn-cubecl/src/fusion/on_write/trace/executor.rs index a44268764c..1b07fdcda6 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/executor.rs +++ b/crates/burn-cubecl/src/fusion/on_write/trace/executor.rs @@ -11,13 +11,13 @@ use super::{HandleInput, HandleOutput, LaunchPlan, TensorView, TraceRunner}; use crate::{ fusion::{ on_write::ir::{ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}, - JitFusionHandle, + CubeFusionHandle, }, - BoolElement, JitRuntime, + BoolElement, CubeRuntime, }; /// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context). 
-pub struct LaunchPlanExecutor<'a, R: JitRuntime> { +pub struct LaunchPlanExecutor<'a, R: CubeRuntime> { scalars: &'a BTreeMap, views: &'a Vec, ops: &'a Vec, @@ -25,13 +25,13 @@ pub struct LaunchPlanExecutor<'a, R: JitRuntime> { } #[derive(new)] -pub struct ExecutionError> { +pub struct ExecutionError> { pub runner_error: Runner::Error, pub handles_input: Vec>, pub handles_output: Vec>, } -impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { +impl<'a, R: CubeRuntime> LaunchPlanExecutor<'a, R> { pub fn new( scalars: &'a BTreeMap, views: &'a Vec, @@ -49,7 +49,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { self, client: &ComputeClient, runner: &Runner, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, plan: LaunchPlan<'a, R>, ) -> Result<(), ExecutionError> { let reference = match plan.reference { @@ -95,7 +95,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { fn register_inputs<'h>( &self, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, handle_inputs: &'h [HandleInput], ) -> GlobalArgsLaunch<'h, R> { let mut inputs = GlobalArgsLaunch::default(); diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/input.rs b/crates/burn-cubecl/src/fusion/on_write/trace/input.rs index 7633787d64..0a3b4aca2f 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/input.rs +++ b/crates/burn-cubecl/src/fusion/on_write/trace/input.rs @@ -1,7 +1,7 @@ use super::TensorView; use crate::{ - fusion::{on_write::settings::FuseSettings, JitFusionHandle}, - JitRuntime, + fusion::{on_write::settings::FuseSettings, CubeFusionHandle}, + CubeRuntime, }; use burn_fusion::stream::Context; use burn_ir::{TensorId, TensorStatus}; @@ -11,7 +11,7 @@ use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors}; /// Fetch and register [input handles](HandleInput) and itendify potential inputs that /// can be used inplace. -pub struct InputPlanner<'a, R: JitRuntime> { +pub struct InputPlanner<'a, R: CubeRuntime> { inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, views: &'a Vec, @@ -20,7 +20,7 @@ pub struct InputPlanner<'a, R: JitRuntime> { _r: PhantomData, } -impl<'a, R: JitRuntime> InputPlanner<'a, R> { +impl<'a, R: CubeRuntime> InputPlanner<'a, R> { pub fn new( inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, @@ -38,7 +38,7 @@ impl<'a, R: JitRuntime> InputPlanner<'a, R> { } } - pub fn run(self, context: &mut Context<'_, JitFusionHandle>, plan: &mut LaunchPlan<'a, R>) { + pub fn run(self, context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>) { for (i, (precision, tensor_relative)) in self.inputs.iter().enumerate() { let mut tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); // Important to take the status of the relative graph and not diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/output.rs b/crates/burn-cubecl/src/fusion/on_write/trace/output.rs index 34a82fda2e..03bc5cff1c 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/output.rs +++ b/crates/burn-cubecl/src/fusion/on_write/trace/output.rs @@ -6,10 +6,10 @@ use cubecl::{client::ComputeClient, ir::Elem}; use crate::{ fusion::{ on_write::ir::{Arg, ElemwiseOp, LayoutInfo}, - strides_dyn_rank, JitFusionHandle, + strides_dyn_rank, CubeFusionHandle, }, tensor::is_contiguous, - BoolElement, JitRuntime, + BoolElement, CubeRuntime, }; use super::{ @@ -21,7 +21,7 @@ use std::collections::BTreeMap; /// Create or reuse handles for the outputs. 
/// /// It is also responsible to select the reference tensor. -pub struct OutputPlanner<'a, R: JitRuntime> { +pub struct OutputPlanner<'a, R: CubeRuntime> { inputs: &'a RegisteredTensors, views: &'a Vec, outputs_sorted: Vec>, @@ -42,7 +42,7 @@ enum OutputKind { Transform(TensorView), } -impl<'a, R: JitRuntime> OutputPlanner<'a, R> { +impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { pub fn new( inputs: &'a RegisteredTensors, outputs: &'a RegisteredTensors, @@ -91,7 +91,7 @@ impl<'a, R: JitRuntime> OutputPlanner<'a, R> { mut self, client: &ComputeClient, device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, ) { // So that we can borrow self during the iteration. @@ -201,7 +201,7 @@ impl<'a, R: JitRuntime> OutputPlanner<'a, R> { fn inplace_output( &mut self, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, @@ -251,7 +251,7 @@ impl<'a, R: JitRuntime> OutputPlanner<'a, R> { &mut self, client: &ComputeClient, device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, @@ -288,7 +288,7 @@ impl<'a, R: JitRuntime> OutputPlanner<'a, R> { }; let size = tensor_global.shape.iter().product::() * Elem::from(dtype).size(); - let handle = JitFusionHandle { + let handle = CubeFusionHandle { client: client.clone(), handle: client.empty(size), device: device.clone(), @@ -316,7 +316,7 @@ impl<'a, R: JitRuntime> OutputPlanner<'a, R> { &mut self, client: &ComputeClient, device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, @@ -342,7 +342,7 @@ impl<'a, R: JitRuntime> OutputPlanner<'a, R> { ) { plan.writes.remove(&output.tensor_relative.id); - let handle = JitFusionHandle { + let handle = CubeFusionHandle { client: client.clone(), handle: original_handle.handle.handle.clone(), device: device.clone(), @@ -376,7 +376,7 @@ impl<'a, R: JitRuntime> OutputPlanner<'a, R> { &mut self, client: &ComputeClient, device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, @@ -399,7 +399,7 @@ impl<'a, R: JitRuntime> OutputPlanner<'a, R> { plan.writes.remove(&output.tensor_relative.id); let strides = original_handle.handle.strides.clone(); - let mut handle = JitFusionHandle { + let mut handle = CubeFusionHandle { client: client.clone(), handle: original_handle.handle.handle.clone(), device: device.clone(), diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/plan.rs b/crates/burn-cubecl/src/fusion/on_write/trace/plan.rs index 5abe13ab21..5a25730963 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/plan.rs +++ b/crates/burn-cubecl/src/fusion/on_write/trace/plan.rs @@ -3,16 +3,16 @@ use std::collections::BTreeMap; use crate::{ fusion::{ on_write::ir::{Arg, ElemwiseOp, ElemwisePrecision}, - JitFusionHandle, + CubeFusionHandle, }, - JitRuntime, + CubeRuntime, }; use burn_ir::{TensorId, TensorIr}; /// The plan is responsible to keep runtime information related to the launch of a fused kernel /// at one place. 
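The allocation made when the output planner creates a brand-new handle is the element count times the element byte size; a tiny worked example with an illustrative helper name:

```rust
// shape [32, 128] with f32 elements (4 bytes): 32 * 128 * 4 = 16_384 bytes.
fn output_buffer_size(shape: &[usize], elem_size_bytes: usize) -> usize {
    shape.iter().product::<usize>() * elem_size_bytes
}

fn main() {
    assert_eq!(output_buffer_size(&[32, 128], 4), 16_384);
}
```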
#[derive(Debug)] -pub(crate) struct LaunchPlan<'a, R: JitRuntime> { +pub(crate) struct LaunchPlan<'a, R: CubeRuntime> { pub potential_inplaces: Vec>, pub global_inputs: Vec, pub global_outputs: Vec, @@ -25,7 +25,7 @@ pub(crate) struct LaunchPlan<'a, R: JitRuntime> { pub rank: usize, } -impl LaunchPlan<'_, R> { +impl LaunchPlan<'_, R> { pub fn new( reads: &BTreeMap>, writes: &BTreeMap, @@ -47,7 +47,7 @@ impl LaunchPlan<'_, R> { } #[derive(Debug)] -pub enum HandleOutput { +pub enum HandleOutput { Alias { input_pos: usize, precision: ElemwisePrecision, @@ -55,18 +55,18 @@ pub enum HandleOutput { Owned { global_id: TensorId, precision: ElemwisePrecision, - handle: JitFusionHandle, + handle: CubeFusionHandle, global_shape: Vec, vectorization: u8, }, } #[derive(Debug)] -pub struct HandleInput { +pub struct HandleInput { pub relative_id: TensorId, pub global_id: TensorId, pub precision: ElemwisePrecision, - pub handle: JitFusionHandle, + pub handle: CubeFusionHandle, pub global_shape: Vec, pub vectorization: u8, } diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/runner.rs b/crates/burn-cubecl/src/fusion/on_write/trace/runner.rs index b76e121596..1bb67a3541 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/runner.rs +++ b/crates/burn-cubecl/src/fusion/on_write/trace/runner.rs @@ -1,5 +1,5 @@ use super::super::ir::{ElemwiseConfig, GlobalArgsLaunch}; -use crate::{fusion::JitFusionHandle, JitRuntime}; +use crate::{fusion::CubeFusionHandle, CubeRuntime}; use burn_ir::{TensorId, TensorIr}; use cubecl::prelude::*; use std::collections::BTreeMap; @@ -7,7 +7,7 @@ use std::collections::BTreeMap; /// A trace runner is responsible for determining the vectorization factor as well as launching /// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch) /// with a provided [element wise config](ElemwiseConfig). -pub trait TraceRunner { +pub trait TraceRunner { /// The error that might happen while running the trace. type Error; @@ -23,7 +23,7 @@ pub trait TraceRunner { /// The vectorization factor for all inputs and outputs. fn vectorization<'a>( vectorizations: &mut BTreeMap, - handles_inputs: impl Iterator>, + handles_inputs: impl Iterator>, inputs: impl Iterator, outputs: impl Iterator, reshaped: impl Iterator, @@ -40,9 +40,9 @@ pub trait TraceRunner { } } -fn vectorization_default<'a, R: JitRuntime>( +fn vectorization_default<'a, R: CubeRuntime>( vectorizations: &mut BTreeMap, - handles_inputs: impl Iterator>, + handles_inputs: impl Iterator>, inputs: impl Iterator, outputs: impl Iterator, reshaped: impl Iterator, @@ -57,7 +57,7 @@ fn vectorization_default<'a, R: JitRuntime>( // The default version uses the last dimension as vectorization axis and assumes a // perpendicular contiguous line. - let vectorization_input = |handle: &JitFusionHandle, desc: &TensorIr| { + let vectorization_input = |handle: &CubeFusionHandle, desc: &TensorIr| { let rank = handle.strides.len(); // Last dimension strides should be 1, otherwise vecX won't be contiguous. 
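The comment above states the rule used to pick a vectorization (line) width; a standalone sketch of that rule follows, with an illustrative function name and candidate widths.

```rust
// A tensor can be loaded as lines of width `w` only if its last dimension is
// contiguous (stride 1) and its length divides evenly by `w`; otherwise fall
// back to scalar access (width 1).
fn pick_line_size(shape: &[usize], strides: &[usize], candidates: &[usize]) -> usize {
    let last = shape.len() - 1;
    if strides[last] != 1 {
        return 1;
    }
    candidates
        .iter()
        .copied()
        .find(|&w| shape[last] % w == 0)
        .unwrap_or(1)
}

fn main() {
    assert_eq!(pick_line_size(&[8, 6], &[6, 1], &[4, 2]), 2);
    assert_eq!(pick_line_size(&[8, 6], &[1, 8], &[4, 2]), 1);
}
```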
@@ -120,7 +120,7 @@ fn vectorization_default<'a, R: JitRuntime>( Vect::Max(1) }; - let vectorization_swapped = |handle: &JitFusionHandle, + let vectorization_swapped = |handle: &CubeFusionHandle, swapped: &TensorIr, original: &TensorIr, multi_reads: bool, diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/vectorization.rs b/crates/burn-cubecl/src/fusion/on_write/trace/vectorization.rs index 70b92c641f..82361ff8ee 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/vectorization.rs +++ b/crates/burn-cubecl/src/fusion/on_write/trace/vectorization.rs @@ -9,15 +9,15 @@ use burn_ir::TensorId; use crate::{ fusion::{ on_write::{ir::ElemwiseOp, settings::FuseSettings}, - JitFusionHandle, + CubeFusionHandle, }, - JitRuntime, + CubeRuntime, }; use super::{HandleOutput, LaunchPlan, TensorView, TraceRunner}; /// Select the best vectorization factor for each tensor handle. -pub struct VectorizationPlanner<'a, R: JitRuntime> { +pub struct VectorizationPlanner<'a, R: CubeRuntime> { settings: &'a FuseSettings, views: &'a Vec, reads: &'a BTreeMap>, @@ -25,7 +25,7 @@ pub struct VectorizationPlanner<'a, R: JitRuntime> { _r: PhantomData, } -impl<'a, R: JitRuntime> VectorizationPlanner<'a, R> { +impl<'a, R: CubeRuntime> VectorizationPlanner<'a, R> { pub fn new( views: &'a Vec, reads: &'a BTreeMap>, @@ -42,7 +42,7 @@ impl<'a, R: JitRuntime> VectorizationPlanner<'a, R> { } pub fn run>( self, - context: &mut Context<'_, JitFusionHandle>, + context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, ) { let tensors_reshaped = self.views.iter().filter_map(|view| match view { diff --git a/crates/burn-cubecl/src/fusion/tune.rs b/crates/burn-cubecl/src/fusion/tune.rs index 8c45f93bb0..c6c9cbab4a 100644 --- a/crates/burn-cubecl/src/fusion/tune.rs +++ b/crates/burn-cubecl/src/fusion/tune.rs @@ -1,5 +1,5 @@ -use super::JitFusionHandle; -use crate::JitRuntime; +use super::CubeFusionHandle; +use crate::CubeRuntime; use burn_fusion::stream::{Context, ContextOwned}; /// Fusion context used when tuning kernels. @@ -7,9 +7,9 @@ use burn_fusion::stream::{Context, ContextOwned}; /// Either the original context is returned or a fork of the original. /// The fork is only given when performing autotuning, and not when actually performing the /// operation. -pub enum TuneContext<'a, R: JitRuntime> { - Original(&'a mut Context<'a, JitFusionHandle>), - Fork(Box>>), +pub enum TuneContext<'a, R: CubeRuntime> { + Original(&'a mut Context<'a, CubeFusionHandle>), + Fork(Box>>), } /// Fusion input wrapper containing the context and the optimization. @@ -18,7 +18,7 @@ pub enum TuneContext<'a, R: JitRuntime> { /// /// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions /// are made based on its behavior. -pub struct TuneInput { +pub struct TuneInput { context: UnsafeTuneContext, optimization: *const O, } @@ -33,17 +33,17 @@ pub struct TuneInput { /// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are /// tuned using a cloned version of the input; therefore, a fork of the context will be used to find /// the best kernel to use, which can be async. 
-enum UnsafeTuneContext { - Original(*mut Context<'static, JitFusionHandle>), - Fork(Box>>), +enum UnsafeTuneContext { + Original(*mut Context<'static, CubeFusionHandle>), + Fork(Box>>), } -unsafe impl Send for UnsafeTuneContext {} -unsafe impl Send for TuneInput {} +unsafe impl Send for UnsafeTuneContext {} +unsafe impl Send for TuneInput {} -impl TuneInput { +impl TuneInput { /// Create a new autotune input from the [context](Context) and an optimization. - pub fn new(context: &mut Context>, optimization: &O) -> Self { + pub fn new(context: &mut Context>, optimization: &O) -> Self { let context = UnsafeTuneContext::new(context); // We can erase the lifetime for the same reason we do with the context. let optimization = core::ptr::from_ref(optimization); @@ -65,8 +65,8 @@ impl TuneInput { } } -impl UnsafeTuneContext { - fn new(context: &mut Context<'_, JitFusionHandle>) -> Self { +impl UnsafeTuneContext { + fn new(context: &mut Context<'_, CubeFusionHandle>) -> Self { let ptr = core::ptr::from_mut(context); // It is necessary for the lifetime. @@ -84,7 +84,7 @@ impl UnsafeTuneContext { } } -impl Clone for TuneInput { +impl Clone for TuneInput { fn clone(&self) -> Self { Self { context: self.context.clone(), @@ -93,11 +93,11 @@ impl Clone for TuneInput { } } -impl Clone for UnsafeTuneContext { +impl Clone for UnsafeTuneContext { fn clone(&self) -> Self { let context = match self { UnsafeTuneContext::Original(ptr) => { - let context: &mut Context<'static, JitFusionHandle> = + let context: &mut Context<'static, CubeFusionHandle> = unsafe { ptr.as_mut().unwrap() }; context.fork() } diff --git a/crates/burn-cubecl/src/kernel/binary.rs b/crates/burn-cubecl/src/kernel/binary.rs index f0da764a7a..98222a5de2 100644 --- a/crates/burn-cubecl/src/kernel/binary.rs +++ b/crates/burn-cubecl/src/kernel/binary.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use burn_tensor::Shape; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, @@ -160,10 +160,10 @@ pub(crate) fn kernel_binop( out[offset_out] = O::BinaryOp::::execute(lhs[offset_lhs], rhs[offset_rhs]); } -pub(crate) fn launch_binop( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub(crate) fn launch_binop( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { let ndims = lhs.shape.num_dims(); let line_size_lhs = tensor_line_size_parallel( R::line_size_elem(&E::as_elem_native_unchecked()), @@ -247,10 +247,10 @@ pub(crate) fn launch_binop( } } -pub(crate) fn launch_scalar_binop( - mut tensor: JitTensor, +pub(crate) fn launch_scalar_binop( + mut tensor: CubeTensor, scalar: E, -) -> JitTensor { +) -> CubeTensor { if !tensor.is_contiguous_buffer() { tensor = into_contiguous(tensor); } diff --git a/crates/burn-cubecl/src/kernel/binary_int.rs b/crates/burn-cubecl/src/kernel/binary_int.rs index 390bfc479e..e566e837bc 100644 --- a/crates/burn-cubecl/src/kernel/binary_int.rs +++ b/crates/burn-cubecl/src/kernel/binary_int.rs @@ -1,4 +1,4 @@ -use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use crate::{ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime, IntElement}; use burn_tensor::Shape; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, @@ -133,10 +133,10 @@ pub(crate) fn kernel_binop_int( out[offset_out] = 
O::BinaryOp::::execute(lhs[offset_lhs], rhs[offset_rhs]); } -pub(crate) fn launch_binop_int( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub(crate) fn launch_binop_int( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { let ndims = lhs.shape.num_dims(); let line_size_lhs = tensor_line_size_parallel( R::line_size_elem(&E::as_elem_native_unchecked()), @@ -220,10 +220,10 @@ pub(crate) fn launch_binop_int( - mut tensor: JitTensor, +pub(crate) fn launch_scalar_binop_int( + mut tensor: CubeTensor, scalar: E, -) -> JitTensor { +) -> CubeTensor { if !tensor.is_contiguous_buffer() { tensor = into_contiguous(tensor); } diff --git a/crates/burn-cubecl/src/kernel/cast/base.rs b/crates/burn-cubecl/src/kernel/cast/base.rs index 43b24f071a..0aade1e4b8 100644 --- a/crates/burn-cubecl/src/kernel/cast/base.rs +++ b/crates/burn-cubecl/src/kernel/cast/base.rs @@ -1,4 +1,4 @@ -use crate::{tensor::JitTensor, JitElement, JitRuntime}; +use crate::{tensor::CubeTensor, CubeElement, CubeRuntime}; use cubecl::linalg::tensor::index_offset_with_layout; use cubecl::{calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor}; use std::any::TypeId; @@ -30,9 +30,11 @@ pub(crate) fn cast_element( /// Cast a tensor to the given element type. /// /// Note: When input element is semantically a boolean, prefer bool_cast function. -pub fn cast(input: JitTensor) -> JitTensor { +pub fn cast( + input: CubeTensor, +) -> CubeTensor { if TypeId::of::() == TypeId::of::() { - return JitTensor::new_contiguous( + return CubeTensor::new_contiguous( input.client, input.device, input.shape, @@ -53,7 +55,7 @@ pub fn cast(input: JitTensor) calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); let client = input.client.clone(); let handle = client.empty(num_elems * core::mem::size_of::()); - let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( client.clone(), input.device.clone(), input.shape.clone(), diff --git a/crates/burn-cubecl/src/kernel/cast/bool_cast.rs b/crates/burn-cubecl/src/kernel/cast/bool_cast.rs index 74e55888e1..ebb3b89a44 100644 --- a/crates/burn-cubecl/src/kernel/cast/bool_cast.rs +++ b/crates/burn-cubecl/src/kernel/cast/bool_cast.rs @@ -1,4 +1,4 @@ -use crate::{tensor::JitTensor, BoolElement, JitElement, JitRuntime}; +use crate::{tensor::CubeTensor, BoolElement, CubeElement, CubeRuntime}; use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim}; #[cube(launch)] @@ -16,12 +16,12 @@ fn bool_cast_kernel(input: &Tensor, output: &mut Tens /// where any non-zero value means true. Depending how it was created /// it may hold an uncanny bit combination. Naively casting it would not /// necessarily yield 0 or 1. 
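To make the doc comment above concrete, here is a plain-Rust illustration of why booleans get their own cast. `false_val` is assumed to mirror the `true_val` method shown earlier in this patch, and the function names are illustrative.

```rust, ignore
use burn_cubecl::BoolElement;

// A numeric cast keeps the stored value: 7u32 becomes 7.0f32.
fn numeric_cast(x: u32) -> f32 {
    x as f32
}

// A boolean cast collapses any non-zero bit pattern to the canonical true
// value, which is what `bool_cast` does per element on the device.
fn boolean_cast<BT: BoolElement>(x: BT) -> BT {
    if x != BT::false_val() {
        BT::true_val()
    } else {
        BT::false_val()
    }
}
```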
-pub fn bool_cast( - tensor: JitTensor, -) -> JitTensor { +pub fn bool_cast( + tensor: CubeTensor, +) -> CubeTensor { let num_elems = tensor.shape.num_elements(); let buffer = tensor.client.empty(num_elems * core::mem::size_of::()); - let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( tensor.client.clone(), tensor.device.clone(), tensor.shape.clone(), diff --git a/crates/burn-cubecl/src/kernel/clamp.rs b/crates/burn-cubecl/src/kernel/clamp.rs index ec2bc93d1f..626487cdfc 100644 --- a/crates/burn-cubecl/src/kernel/clamp.rs +++ b/crates/burn-cubecl/src/kernel/clamp.rs @@ -1,10 +1,10 @@ use cubecl::prelude::*; use crate::{ - element::JitElement, + element::CubeElement, kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}, - tensor::JitTensor, - JitRuntime, + tensor::CubeTensor, + CubeRuntime, }; #[derive(CubeLaunch)] @@ -13,11 +13,11 @@ struct Options { max_value: C, } -pub(crate) fn clamp( - input: JitTensor, +pub(crate) fn clamp( + input: CubeTensor, min_value: E, max_value: E, -) -> JitTensor { +) -> CubeTensor { struct ClampOp; #[cube] diff --git a/crates/burn-cubecl/src/kernel/comparison.rs b/crates/burn-cubecl/src/kernel/comparison.rs index a6de9025bb..8354e5bf46 100644 --- a/crates/burn-cubecl/src/kernel/comparison.rs +++ b/crates/burn-cubecl/src/kernel/comparison.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use crate::{ - element::JitElement, ops::numeric::empty_device, tensor::JitTensor, BoolElement, JitRuntime, + element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, BoolElement, CubeRuntime, }; use burn_tensor::Shape; use cubecl::{ @@ -130,10 +130,10 @@ pub(crate) fn kernel_cmp>( out[offset_out] = Line::cast_from(O::execute(lhs[offset_lhs], rhs[offset_rhs])); } -pub(crate) fn launch_cmp>( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub(crate) fn launch_cmp>( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { let ndims = lhs.shape.num_dims(); let vectorization_factor_lhs = tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, ndims - 1); @@ -174,7 +174,7 @@ pub(crate) fn launch_cmp, >( - mut tensor: JitTensor, + mut tensor: CubeTensor, scalar: E, -) -> JitTensor { +) -> CubeTensor { if !tensor.is_contiguous_buffer() { tensor = into_contiguous(tensor); } @@ -259,7 +259,7 @@ pub(crate) fn launch_scalar_cmp< TensorArg::alias(0), ); - JitTensor::new( + CubeTensor::new( tensor.client, tensor.handle, tensor.shape, @@ -287,72 +287,72 @@ pub(crate) fn launch_scalar_cmp< } } -pub fn equal( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn equal( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_cmp::(lhs, rhs) } -pub fn greater( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn greater( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_cmp::(lhs, rhs) } -pub fn greater_equal( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn greater_equal( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_cmp::(lhs, rhs) } -pub fn lower( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn lower( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_cmp::(lhs, rhs) } -pub fn lower_equal( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn lower_equal( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_cmp::(lhs, rhs) } -pub fn equal_elem( - lhs: JitTensor, +pub fn equal_elem( + lhs: CubeTensor, rhs: E, -) -> JitTensor { +) -> CubeTensor { launch_scalar_cmp::(lhs, rhs) } -pub fn greater_elem( - 
lhs: JitTensor, +pub fn greater_elem( + lhs: CubeTensor, rhs: E, -) -> JitTensor { +) -> CubeTensor { launch_scalar_cmp::(lhs, rhs) } -pub fn lower_elem( - lhs: JitTensor, +pub fn lower_elem( + lhs: CubeTensor, rhs: E, -) -> JitTensor { +) -> CubeTensor { launch_scalar_cmp::(lhs, rhs) } -pub fn greater_equal_elem( - lhs: JitTensor, +pub fn greater_equal_elem( + lhs: CubeTensor, rhs: E, -) -> JitTensor { +) -> CubeTensor { launch_scalar_cmp::(lhs, rhs) } -pub fn lower_equal_elem( - lhs: JitTensor, +pub fn lower_equal_elem( + lhs: CubeTensor, rhs: E, -) -> JitTensor { +) -> CubeTensor { launch_scalar_cmp::(lhs, rhs) } diff --git a/crates/burn-cubecl/src/kernel/contiguous.rs b/crates/burn-cubecl/src/kernel/contiguous.rs index b21d032c78..c88a18bc0e 100644 --- a/crates/burn-cubecl/src/kernel/contiguous.rs +++ b/crates/burn-cubecl/src/kernel/contiguous.rs @@ -1,7 +1,7 @@ -use crate::{execute_with_dtype, tensor::JitTensor, JitRuntime}; +use crate::{execute_with_dtype, tensor::CubeTensor, CubeRuntime}; /// Make a jit tensor contiguous. -pub fn into_contiguous(tensor: JitTensor) -> JitTensor { +pub fn into_contiguous(tensor: CubeTensor) -> CubeTensor { if tensor.is_contiguous() { return tensor; } @@ -12,7 +12,7 @@ pub fn into_contiguous(tensor: JitTensor) -> JitTensor { &tensor.as_handle_ref(), ); - JitTensor::new( + CubeTensor::new( tensor.client, output.handle, output.shape.into(), diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/base.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/base.rs index f015677a2b..a39524fa7d 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/base.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/base.rs @@ -1,7 +1,7 @@ use burn_tensor::ops::{ConvOptions, ConvTransposeOptions}; use crate::{ - kernel::conv::ConvLaunchError, tensor::JitTensor, FloatElement, IntElement, JitRuntime, + kernel::conv::ConvLaunchError, tensor::CubeTensor, CubeRuntime, FloatElement, IntElement, }; #[cfg(feature = "autotune")] @@ -71,13 +71,13 @@ impl Default for ConvTranspose2dStrategy { /// * `options` - The options to use for the convolution /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. /// -pub fn conv2d( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub fn conv2d( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<2>, strategy: Conv2dStrategy, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { match strategy { Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), #[cfg(feature = "autotune")] @@ -98,13 +98,13 @@ pub fn conv2d( /// * `options` - The options to use for the convolution /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. 
/// -pub fn conv_transpose2d( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub fn conv_transpose2d( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvTransposeOptions<2>, strategy: ConvTranspose2dStrategy, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { match strategy { ConvTranspose2dStrategy::Direct => { conv_transpose2d_direct::(input, weight, bias, options) diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/col2im.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/col2im.rs index 4f6931f86d..b19763fcb5 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/col2im.rs @@ -12,8 +12,8 @@ use crate::{ slice, }, ops::{numeric::empty_device, reshape, swap_dims}, - tensor::JitTensor, - FloatElement, JitElement, JitRuntime, + tensor::CubeTensor, + CubeElement, CubeRuntime, FloatElement, }; use super::batches_per_run; @@ -25,12 +25,12 @@ use super::batches_per_run; /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -pub fn conv_transpose2d_col2im( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub fn conv_transpose2d_col2im( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvTransposeOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims(); let [batch_size, _, input_h, input_w] = input.shape.dims(); let groups = options.groups; @@ -117,7 +117,10 @@ pub fn conv_transpose2d_col2im( } } -pub(crate) fn index(tensor: JitTensor, i: usize) -> JitTensor { +pub(crate) fn index( + tensor: CubeTensor, + i: usize, +) -> CubeTensor { #[allow(clippy::single_range_in_vec_init)] let mut indices = vec![i..i + 1]; for dim in tensor.shape.dims[1..].iter() { @@ -131,11 +134,11 @@ pub(crate) fn index(tensor: JitTensor, i: usize } #[allow(clippy::too_many_arguments)] -fn execute( - input: JitTensor, - weight: JitTensor, - bias: Option>, - image: JitTensor, +fn execute( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, + image: CubeTensor, options: ConvTransposeOptions<2>, kernel_h: usize, kernel_w: usize, @@ -160,10 +163,10 @@ fn execute( } #[allow(clippy::too_many_arguments)] -fn col2im( - columns: JitTensor, - bias: Option>, - out: JitTensor, +fn col2im( + columns: CubeTensor, + bias: Option>, + out: CubeTensor, kernel_h: usize, kernel_w: usize, out_h: usize, diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/direct.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/direct.rs index 1cd24f7c0c..52779b736f 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/direct.rs @@ -10,8 +10,8 @@ use crate::{ numeric::{empty_device, zeros_device}, reshape, }, - tensor::JitTensor, - FloatElement, JitRuntime, + tensor::CubeTensor, + CubeRuntime, FloatElement, }; #[derive(CubeLaunch)] @@ -120,12 +120,12 @@ fn direct_conv2d_kernel( /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -pub fn conv2d_direct( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub fn conv2d_direct( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { let [batch_size, _, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let channels_per_group = out_channels / 
options.groups; diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/launch.rs index ad70a9b825..f4a787f1dd 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/launch.rs @@ -28,8 +28,8 @@ use crate::{ into_contiguous, }, ops::{numeric::empty_device, permute, reshape}, - tensor::JitTensor, - FloatElement, JitElement, JitRuntime, + tensor::CubeTensor, + CubeElement, CubeRuntime, FloatElement, }; /// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul @@ -39,12 +39,12 @@ use crate::{ /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution -pub fn conv2d_gemm_cmma_large_m( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub fn conv2d_gemm_cmma_large_m( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { conv2d_gemm_cmma_strategy::(input, weight, bias, options) } @@ -55,26 +55,26 @@ pub fn conv2d_gemm_cmma_large_m( /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution -pub fn conv2d_gemm_cmma_balanced( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub fn conv2d_gemm_cmma_balanced( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { conv2d_gemm_cmma_strategy::(input, weight, bias, options) } fn conv2d_gemm_cmma_strategy< - R: JitRuntime, + R: CubeRuntime, F: FloatElement, Alg: Algorithm, S: ConvSelector, >( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { if TypeId::of::() == TypeId::of::() { conv2d_gemm_with_algo::(input, weight, bias, options) } else if TypeId::of::() == TypeId::of::() || TypeId::of::() == TypeId::of::() @@ -95,18 +95,18 @@ fn conv2d_gemm_cmma_strategy< /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution pub fn conv2d_gemm_with_algo< - R: JitRuntime, + R: CubeRuntime, SP: ConvPrecision, Alg: Algorithm, S: ConvSelector, >( - input: JitTensor, - weight: JitTensor, - bias: Option>, + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<2>, -) -> Result, ConvLaunchError> +) -> Result, ConvLaunchError> where - SP::EG: JitElement, + SP::EG: CubeElement, { if options.groups != 1 { return Err(ConvLaunchError::Groups(options.groups)); @@ -226,7 +226,7 @@ where Ok(permute(out, &[0, 3, 1, 2])) } -pub(crate) fn has_tf32(c: &JitTensor) -> bool { +pub(crate) fn has_tf32(c: &CubeTensor) -> bool { c.client .properties() .feature_enabled(Feature::Type(Elem::Float(FloatKind::TF32))) diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/selection.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/selection.rs index f12f4fdc3c..ebccde8fca 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/selection.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/selection.rs @@ -2,7 +2,7 @@ use super::{ algorithm::{Algorithm, ImplicitCmmaConv}, precision::ConvPrecision, }; -use crate::JitRuntime; +use crate::CubeRuntime; use 
cubecl::linalg::matmul::components::{CompleteStageTiling, MatmulSelection, MatmulSize}; pub struct ConvSelection { @@ -10,7 +10,8 @@ pub struct ConvSelection { } pub trait ConvSelector { - fn select_kernel(plane_dim: u32) -> (A::Selection, A::Input); + fn select_kernel(plane_dim: u32) + -> (A::Selection, A::Input); } /// Large m stage size for the usual case where `batch_size * out_h * out_w` is significantly larger @@ -21,7 +22,7 @@ pub struct Large; pub struct Balanced; impl ConvSelector for Large { - fn select_kernel( + fn select_kernel( plane_dim: u32, ) -> ( ::Selection, @@ -48,7 +49,7 @@ impl ConvSelector for Large { } impl ConvSelector for Balanced { - fn select_kernel( + fn select_kernel( plane_dim: u32, ) -> ( ::Selection, diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/im2col.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/im2col.rs index 09ce56898b..e0cc896754 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/im2col.rs @@ -12,8 +12,8 @@ use crate::{ AddOp, }, ops::{numeric::empty_device, reshape, swap_dims}, - tensor::JitTensor, - FloatElement, JitRuntime, + tensor::CubeTensor, + CubeRuntime, FloatElement, }; #[derive(CubeLaunch)] @@ -132,14 +132,14 @@ pub(crate) fn batches_per_run( Ok(1) } -fn im2col( - input: JitTensor, +fn im2col( + input: CubeTensor, options: ConvOptions<2>, kernel_h: usize, kernel_w: usize, out_h: usize, out_w: usize, -) -> JitTensor { +) -> CubeTensor { let input = into_contiguous(input); let [batch_size, in_channels, _, _] = input.shape.dims(); @@ -196,12 +196,12 @@ fn im2col( /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -pub fn conv2d_im2col( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub fn conv2d_im2col( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { let [batch_size, in_channels, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -268,12 +268,12 @@ pub fn conv2d_im2col( Ok(out) } -fn execute_1x1_kernel( - input: JitTensor, - weight: JitTensor, - bias: Option>, +fn execute_1x1_kernel( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { let [batch_size, _, height, width] = input.shape.dims(); let [out_channels, in_c_per_grp, _, _] = weight.shape.dims(); let groups = options.groups; @@ -295,10 +295,10 @@ fn execute_1x1_kernel( Ok(swap_dims(out, 0, 1)) } -fn execute( - input: JitTensor, - weight: JitTensor, - out: JitTensor, +fn execute( + input: CubeTensor, + weight: CubeTensor, + out: CubeTensor, options: ConvOptions<2>, out_h: usize, out_w: usize, diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/implicit_gemm.rs index a9bada754c..1aa224c3f7 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/implicit_gemm.rs @@ -18,8 +18,8 @@ use crate::{ numeric::{empty_device, zeros_device}, permute, }, - tensor::JitTensor, - FloatElement, JitRuntime, + tensor::CubeTensor, + CubeRuntime, FloatElement, }; use super::nchw_to_nhwc; @@ -31,12 +31,12 @@ use super::nchw_to_nhwc; /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -pub fn 
conv2d_implicit_gemm( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub fn conv2d_implicit_gemm( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { let is_tf32 = F::as_elem_native_unchecked() == Elem::Float(FloatKind::F32) && input .client @@ -636,7 +636,7 @@ fn load_weight_tile( } #[allow(clippy::too_many_arguments)] -pub(crate) fn check_availability( +pub(crate) fn check_availability( batch_size: usize, in_channels: usize, out_channels: usize, @@ -738,7 +738,7 @@ fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize { batch_size.div_ceil(target) * target } -fn find_cmma_size( +fn find_cmma_size( client: &ComputeClient, gemm_m: u32, gemm_k: u32, @@ -752,7 +752,7 @@ fn find_cmma_size( .map(|(m, k, n)| (m as u32, n as u32, k as u32)) } -fn supported_cmma_sizes( +fn supported_cmma_sizes( client: &ComputeClient, ) -> Vec<(u8, u8, u8)> { let (requested_sizes, matrix_elem) = match ( diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/layout_swap.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/layout_swap.rs index 7cbe09dbc0..d493e8a7f4 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/layout_swap.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/layout_swap.rs @@ -3,8 +3,8 @@ use cubecl::{prelude::*, CubeCount, CubeDim}; use crate::{ ops::{max_vectorization, numeric::empty_device}, - tensor::JitTensor, - JitElement, JitRuntime, + tensor::CubeTensor, + CubeElement, CubeRuntime, }; /// Efficiently transpose an NCHW tensor to NHWC for use in kernels that prefer NHWC for performance. @@ -18,7 +18,7 @@ use crate::{ /// /// The input in NHWC format /// -pub fn nchw_to_nhwc(input: JitTensor) -> JitTensor { +pub fn nchw_to_nhwc(input: CubeTensor) -> CubeTensor { let tiles_per_block = 8; let warp_size = 32; let tile_dim = 16; diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/transpose_direct.rs index a8cd1ceb7f..dc6cabfef6 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/transpose_direct.rs @@ -1,14 +1,14 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ - element::JitElement, + element::CubeElement, kernel::{conv::ConvLaunchError, into_contiguous}, ops::{ numeric::{empty_device, zeros_device}, reshape, }, - tensor::JitTensor, - JitRuntime, + tensor::CubeTensor, + CubeRuntime, }; use burn_tensor::{ops::ConvTransposeOptions, Shape}; @@ -121,12 +121,12 @@ fn conv_transpose2d_direct_kernel( /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -pub fn conv_transpose2d_direct( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub fn conv_transpose2d_direct( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvTransposeOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { let input = into_contiguous(input); let weight = into_contiguous(weight); let [batch_size, _, in_height, in_width] = input.shape.dims(); diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/tune/conv2d.rs index 36d12e2255..10307482c2 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/tune/conv2d.rs @@ -10,20 +10,20 @@ use crate::{ }, prng::random_uniform, }, - tensor::JitTensor, - FloatElement, 
JitAutotuneKey, JitRuntime, JitTuneId, + tensor::CubeTensor, + CubeAutotuneKey, CubeRuntime, CubeTuneId, FloatElement, }; /// Executes autotune on conv2d operations -pub fn conv2d_autotune( - input: JitTensor, - weights: JitTensor, - bias: Option>, +pub fn conv2d_autotune( + input: CubeTensor, + weights: CubeTensor, + bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> CubeTensor { let client = input.client.clone(); - static TUNER: LocalTuner = local_tuner!(); + static TUNER: LocalTuner = local_tuner!(); let tunables = TunableSet::new(create_key::, create_conv2d_input::) .with_tunable(conv2d_direct::) @@ -33,28 +33,28 @@ pub fn conv2d_autotune( .with_tunable(conv2d_gemm_cmma_balanced::); TUNER.execute( - &JitTuneId::new::(&input.device), + &CubeTuneId::new::(&input.device), &client, &tunables, (input, weights, bias, options), ) } -pub fn create_conv2d_input( - key: &JitAutotuneKey, - input: &JitTensor, - _weights: &JitTensor, - _bias: &Option>, +pub fn create_conv2d_input( + key: &CubeAutotuneKey, + input: &CubeTensor, + _weights: &CubeTensor, + _bias: &Option>, options: &ConvOptions<2>, ) -> ( - JitTensor, - JitTensor, - Option>, + CubeTensor, + CubeTensor, + Option>, ConvOptions<2>, ) { let device = &input.device; let key = match key { - JitAutotuneKey::Conv2d(key) => key, + CubeAutotuneKey::Conv2d(key) => key, _ => unreachable!(), }; @@ -73,12 +73,12 @@ pub fn create_conv2d_input( (input, weights, bias, options.clone()) } -fn create_key( - input: &JitTensor, - weights: &JitTensor, - bias: &Option>, +fn create_key( + input: &CubeTensor, + weights: &CubeTensor, + bias: &Option>, options: &ConvOptions<2>, -) -> JitAutotuneKey { +) -> CubeAutotuneKey { let [batch_size, in_channels, height, width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weights.shape.dims(); let ConvOptions { @@ -87,7 +87,7 @@ fn create_key( dilation, groups, } = options.clone(); - JitAutotuneKey::Conv2d(Conv2dAutotuneKey::new( + CubeAutotuneKey::Conv2d(Conv2dAutotuneKey::new( [kernel_h, kernel_w], stride, padding, diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/tune/conv_transpose2d.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/tune/conv_transpose2d.rs index df0159b75d..05c5dee9ea 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/tune/conv_transpose2d.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/tune/conv_transpose2d.rs @@ -6,49 +6,49 @@ use crate::{ conv::{conv_transpose2d_col2im, conv_transpose2d_direct}, prng::random_uniform, }, - tensor::JitTensor, - FloatElement, JitAutotuneKey, JitRuntime, JitTuneId, + tensor::CubeTensor, + CubeAutotuneKey, CubeRuntime, CubeTuneId, FloatElement, }; use super::ConvTranspose2dAutotuneKey; /// Executes autotune on conv2d operations -pub fn conv_transpose2d_autotune( - input: JitTensor, - weights: JitTensor, - bias: Option>, +pub fn conv_transpose2d_autotune( + input: CubeTensor, + weights: CubeTensor, + bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> CubeTensor { let client = input.client.clone(); - static TUNER: LocalTuner = local_tuner!(); + static TUNER: LocalTuner = local_tuner!(); let tune_set = TunableSet::new(create_key::, create_transpose2d_input::) .with_tunable(conv_transpose2d_direct::) .with_tunable(conv_transpose2d_col2im::); TUNER.execute( - &JitTuneId::new::(&input.device), + &CubeTuneId::new::(&input.device), &client, &tune_set, (input, weights, bias, options), ) } -pub fn create_transpose2d_input( - key: &JitAutotuneKey, - input: &JitTensor, - _weights: &JitTensor, - _bias: &Option>, +pub 
fn create_transpose2d_input( + key: &CubeAutotuneKey, + input: &CubeTensor, + _weights: &CubeTensor, + _bias: &Option>, options: &ConvTransposeOptions<2>, ) -> ( - JitTensor, - JitTensor, - Option>, + CubeTensor, + CubeTensor, + Option>, ConvTransposeOptions<2>, ) { let key = match key { - JitAutotuneKey::ConvTranspose2d(key) => key, + CubeAutotuneKey::ConvTranspose2d(key) => key, _ => unreachable!(), }; let device = &input.device; @@ -67,12 +67,12 @@ pub fn create_transpose2d_input( (input, weights, bias, options.clone()) } -fn create_key( - input: &JitTensor, - weights: &JitTensor, - bias: &Option>, +fn create_key( + input: &CubeTensor, + weights: &CubeTensor, + bias: &Option>, options: &ConvTransposeOptions<2>, -) -> JitAutotuneKey { +) -> CubeAutotuneKey { let [batch_size, in_channels, height, width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weights.shape.dims(); let ConvTransposeOptions { @@ -82,7 +82,7 @@ fn create_key( groups, padding_out, } = options.clone(); - JitAutotuneKey::ConvTranspose2d(ConvTranspose2dAutotuneKey::new( + CubeAutotuneKey::ConvTranspose2d(ConvTranspose2dAutotuneKey::new( [kernel_h, kernel_w], stride, padding, diff --git a/crates/burn-cubecl/src/kernel/conv/conv3d.rs b/crates/burn-cubecl/src/kernel/conv/conv3d.rs index a616c432b9..1a8a2bbe67 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv3d.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv3d.rs @@ -11,8 +11,8 @@ use crate::{ numeric::{empty_device, zeros_device}, reshape, }, - tensor::JitTensor, - FloatElement, JitRuntime, + tensor::CubeTensor, + CubeRuntime, FloatElement, }; #[derive(CubeLaunch)] @@ -139,12 +139,12 @@ fn conv3d_kernel( output[ABSOLUTE_POS] = sum; } -pub(crate) fn conv3d( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub(crate) fn conv3d( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvOptions<3>, -) -> JitTensor { +) -> CubeTensor { let input = into_contiguous(input); let weight = into_contiguous(weight); let [batch_size, _, in_depth, in_height, in_width] = input.shape.dims(); diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs index 860b14ae6a..0f7919d7a8 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs @@ -1,14 +1,14 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ - element::JitElement, + element::CubeElement, kernel::into_contiguous, ops::{ numeric::{empty_device, zeros_device}, reshape, }, - tensor::JitTensor, - JitRuntime, + tensor::CubeTensor, + CubeRuntime, }; use burn_tensor::{ops::ConvTransposeOptions, Element, Shape}; @@ -145,12 +145,12 @@ fn conv_transpose3d_kernel( output[ABSOLUTE_POS] = sum; } -pub(crate) fn conv_transpose3d( - input: JitTensor, - weight: JitTensor, - bias: Option>, +pub(crate) fn conv_transpose3d( + input: CubeTensor, + weight: CubeTensor, + bias: Option>, options: ConvTransposeOptions<3>, -) -> JitTensor { +) -> CubeTensor { let input = into_contiguous(input); let weight = into_contiguous(weight); let [batch_size, _, in_depth, in_height, in_width] = input.shape.dims(); diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs index 300d714335..12564f38c7 100644 --- a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs @@ -15,8 +15,8 @@ use crate::{ numeric::{ones_device, zeros_device}, reshape, swap_dims, 
}, - tensor::JitTensor, - FloatElement, JitRuntime, + tensor::CubeTensor, + CubeRuntime, FloatElement, }; use super::ConvLaunchError; @@ -188,14 +188,14 @@ pub(crate) fn bilinear_interpolate( result } -pub(crate) fn deform_im2col( - input: JitTensor, - offset: JitTensor, - mask: Option>, +pub(crate) fn deform_im2col( + input: CubeTensor, + offset: CubeTensor, + mask: Option>, options: DeformConvOptions<2>, out_dims: (usize, usize), kernel_dims: (usize, usize), -) -> JitTensor { +) -> CubeTensor { let client = input.client.clone(); let device = input.device.clone(); @@ -257,14 +257,14 @@ pub(crate) fn deform_im2col( output } -pub(crate) fn deform_conv2d( - input: JitTensor, - offset: JitTensor, - weight: JitTensor, - mask: Option>, - bias: Option>, +pub(crate) fn deform_conv2d( + input: CubeTensor, + offset: CubeTensor, + weight: CubeTensor, + mask: Option>, + bias: Option>, options: DeformConvOptions<2>, -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { let input = into_contiguous(input); let offset = into_contiguous(offset); let weight = into_contiguous(weight); diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs index 5840f4dc9f..9a168a3799 100644 --- a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs @@ -20,8 +20,8 @@ use crate::{ numeric::{empty_device, ones_device, zeros_device}, reshape, swap_dims, }, - tensor::JitTensor, - FloatElement, IntElement, JitBackend, JitRuntime, + tensor::CubeTensor, + CubeBackend, CubeRuntime, FloatElement, IntElement, }; use super::{bilinear_interpolate, deform_im2col, index, ConvLaunchError}; @@ -29,26 +29,26 @@ use super::{bilinear_interpolate, deform_im2col, index, ConvLaunchError}; /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. 
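All of the convolution hunks above follow the same mechanical rename: `JitRuntime`/`JitElement`/`JitTensor` bounds become `CubeRuntime`/`CubeElement`/`CubeTensor` (with `FloatElement`, `IntElement`, and `BoolElement` kept as-is), while the function bodies are untouched; the deformable-convolution backward pass below continues that pattern. As a reading aid, a minimal sketch of the post-rename signature shape, with the generic bounds assumed from the imports in these hunks rather than copied verbatim from the source:

```rust
// Old shape (pre-patch, bounds assumed):
//   fn kernel<R: JitRuntime, E: JitElement>(input: JitTensor<R>) -> JitTensor<R>
// New shape after this patch (bounds assumed, not quoted verbatim):
use burn_cubecl::{tensor::CubeTensor, CubeElement, CubeRuntime};

fn kernel_sketch<R: CubeRuntime, E: CubeElement>(input: CubeTensor<R>) -> CubeTensor<R> {
    // Only trait and struct names change; kernel bodies, launch logic and
    // autotune plumbing stay exactly as they were.
    input
}
```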
#[allow(clippy::single_range_in_vec_init)] pub(crate) fn deform_conv2d_backward< - R: JitRuntime, + R: CubeRuntime, E: FloatElement, I: IntElement, BT: BoolElement, >( - input: JitTensor, - offset: JitTensor, - weight: JitTensor, - mask: Option>, - bias: Option>, - out_grad: JitTensor, + input: CubeTensor, + offset: CubeTensor, + weight: CubeTensor, + mask: Option>, + bias: Option>, + out_grad: CubeTensor, options: DeformConvOptions<2>, -) -> Result>, ConvLaunchError> { +) -> Result>, ConvLaunchError> { let [_, _, out_h, out_w] = out_grad.shape.dims(); let [_, _, kernel_h, kernel_w] = weight.shape.dims(); let gradient_bias = bias.map(|bias| { - let grad = JitBackend::::float_sum_dim(out_grad.clone(), 0); - let grad = JitBackend::::float_sum_dim(grad, 2); - let grad = JitBackend::::float_sum_dim(grad, 3); + let grad = CubeBackend::::float_sum_dim(out_grad.clone(), 0); + let grad = CubeBackend::::float_sum_dim(grad, 2); + let grad = CubeBackend::::float_sum_dim(grad, 3); reshape(grad, bias.shape) }); @@ -86,15 +86,15 @@ pub(crate) fn deform_conv2d_backward< )) } -fn compute_weight_grad( - input: JitTensor, - offset: JitTensor, - mask: Option>, - out_grad: JitTensor, +fn compute_weight_grad( + input: CubeTensor, + offset: CubeTensor, + mask: Option>, + out_grad: CubeTensor, options: DeformConvOptions<2>, kernel_dims: (usize, usize), out_dims: (usize, usize), -) -> Result, ConvLaunchError> { +) -> Result, ConvLaunchError> { let [_, in_channels, _, _] = input.shape.dims(); let [_, out_channels, _, _] = out_grad.shape.dims(); let (kernel_h, kernel_w) = kernel_dims; @@ -121,14 +121,14 @@ fn compute_weight_grad( )) } -type InputGradients = (JitTensor, JitTensor, Option>); +type InputGradients = (CubeTensor, CubeTensor, Option>); -fn backward_gradient_inputs( - image: JitTensor, - weight: JitTensor, - offset: JitTensor, - mask: Option>, - out_grad: JitTensor, +fn backward_gradient_inputs( + image: CubeTensor, + weight: CubeTensor, + offset: CubeTensor, + mask: Option>, + out_grad: CubeTensor, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), ) -> Result, ConvLaunchError> { @@ -182,14 +182,14 @@ fn backward_gradient_inputs( Ok((input_gradient, offset_gradient, mask_gradient)) } -fn compute_offset_and_mask_gradient( - columns: JitTensor, - image: JitTensor, - offset: JitTensor, - mask: Option>, +fn compute_offset_and_mask_gradient( + columns: CubeTensor, + image: CubeTensor, + offset: CubeTensor, + mask: Option>, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), -) -> Result<(JitTensor, Option>), ConvLaunchError> { +) -> Result<(CubeTensor, Option>), ConvLaunchError> { let client = offset.client.clone(); let device = offset.device.clone(); let (kernel_height, kernel_width) = kernel_dims; @@ -433,14 +433,14 @@ fn get_coordinate_weight( } } -fn compute_input_grad( - columns: JitTensor, - offset: JitTensor, - mask: Option>, +fn compute_input_grad( + columns: CubeTensor, + offset: CubeTensor, + mask: Option>, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), input_shape: Shape, -) -> JitTensor { +) -> CubeTensor { let client = offset.client.clone(); let device = offset.device.clone(); diff --git a/crates/burn-cubecl/src/kernel/index/flip.rs b/crates/burn-cubecl/src/kernel/index/flip.rs index a682a76eac..79aad23526 100644 --- a/crates/burn-cubecl/src/kernel/index/flip.rs +++ b/crates/burn-cubecl/src/kernel/index/flip.rs @@ -1,5 +1,5 @@ use crate::{ - element::JitElement, ops::numeric::empty_device, tensor::JitTensor, BoolElement, JitRuntime, + element::CubeElement, 
ops::numeric::empty_device, tensor::CubeTensor, BoolElement, CubeRuntime, }; use cubecl::{calculate_cube_count_elemwise, prelude::*}; @@ -33,10 +33,10 @@ fn flip_kernel( output[ABSOLUTE_POS] = input[offset_input]; } -pub(crate) fn flip( - tensor: JitTensor, +pub(crate) fn flip( + tensor: CubeTensor, indices: &[usize], -) -> JitTensor { +) -> CubeTensor { let output = empty_device::( tensor.client.clone(), tensor.device.clone(), @@ -45,11 +45,11 @@ pub(crate) fn flip( flip_on_output::(tensor, output, indices) } -pub(crate) fn flip_on_output( - tensor: JitTensor, - output: JitTensor, +pub(crate) fn flip_on_output( + tensor: CubeTensor, + output: CubeTensor, indices: &[usize], -) -> JitTensor { +) -> CubeTensor { let ndims = tensor.shape.num_dims(); let mut indices_sequence = SequenceArg::<'_, R, BT>::new(); diff --git a/crates/burn-cubecl/src/kernel/index/gather.rs b/crates/burn-cubecl/src/kernel/index/gather.rs index c1aa56072e..c0e7cbd3d9 100644 --- a/crates/burn-cubecl/src/kernel/index/gather.rs +++ b/crates/burn-cubecl/src/kernel/index/gather.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use cubecl::frontend::{Numeric, Tensor, ABSOLUTE_POS}; use cubecl::linalg::tensor::index_offset_with_layout; use cubecl::CubeDim; @@ -32,11 +32,11 @@ fn gather_kernel( output[ABSOLUTE_POS] = input[offset]; } -pub(crate) fn gather( +pub(crate) fn gather( dim: usize, - tensor: JitTensor, - indices: JitTensor, -) -> JitTensor { + tensor: CubeTensor, + indices: CubeTensor, +) -> CubeTensor { let shape_output = indices.shape.clone(); let total_elem = shape_output.num_elements(); let output = empty_device::(tensor.client.clone(), tensor.device.clone(), shape_output); diff --git a/crates/burn-cubecl/src/kernel/index/repeat_dim.rs b/crates/burn-cubecl/src/kernel/index/repeat_dim.rs index b19f9e2b21..f50adda873 100644 --- a/crates/burn-cubecl/src/kernel/index/repeat_dim.rs +++ b/crates/burn-cubecl/src/kernel/index/repeat_dim.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked)] @@ -19,11 +19,11 @@ fn repeat_dim_kernel(input: &Tensor, output: &mut Tensor output[ABSOLUTE_POS] = input[offset_input]; } -pub(crate) fn repeat_dim( - input: JitTensor, +pub(crate) fn repeat_dim( + input: CubeTensor, dim: usize, times: usize, -) -> JitTensor { +) -> CubeTensor { let mut shape = input.shape.clone(); // Create output handle diff --git a/crates/burn-cubecl/src/kernel/index/scatter.rs b/crates/burn-cubecl/src/kernel/index/scatter.rs index 4cca94f824..6f8d0f1f39 100644 --- a/crates/burn-cubecl/src/kernel/index/scatter.rs +++ b/crates/burn-cubecl/src/kernel/index/scatter.rs @@ -1,8 +1,8 @@ use crate::{ - element::JitElement, + element::CubeElement, kernel::{self}, - tensor::JitTensor, - IntElement, JitRuntime, + tensor::CubeTensor, + CubeRuntime, IntElement, }; use cubecl::prelude::*; use cubecl::{calculate_cube_count_elemwise, CubeDim}; @@ -65,12 +65,12 @@ fn scatter_kernel( } } -pub(crate) fn scatter( +pub(crate) fn scatter( dim: usize, - tensor: JitTensor, - indices: JitTensor, - value: JitTensor, -) -> JitTensor { + tensor: CubeTensor, + indices: CubeTensor, + value: CubeTensor, +) -> CubeTensor { let ndims = 
tensor.shape.num_dims(); let mut indices = kernel::into_contiguous(indices); let tensor = kernel::into_contiguous(tensor); diff --git a/crates/burn-cubecl/src/kernel/index/select.rs b/crates/burn-cubecl/src/kernel/index/select.rs index fe664ab420..e92a36f92a 100644 --- a/crates/burn-cubecl/src/kernel/index/select.rs +++ b/crates/burn-cubecl/src/kernel/index/select.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use cubecl::prelude::*; use cubecl::{calculate_cube_count_elemwise, CubeDim}; @@ -28,11 +28,11 @@ fn select_kernel( output[ABSOLUTE_POS] = input[offset_input]; } -pub(crate) fn select( - tensor: JitTensor, +pub(crate) fn select( + tensor: CubeTensor, dim: usize, - indices: JitTensor, -) -> JitTensor { + indices: CubeTensor, +) -> CubeTensor { let ndims = tensor.shape.num_dims(); let mut shape_output = tensor.shape.clone(); shape_output.dims[dim] = indices.shape.dims[0]; diff --git a/crates/burn-cubecl/src/kernel/index/select_assign.rs b/crates/burn-cubecl/src/kernel/index/select_assign.rs index cd4c013f63..609f94f60b 100644 --- a/crates/burn-cubecl/src/kernel/index/select_assign.rs +++ b/crates/burn-cubecl/src/kernel/index/select_assign.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, tensor::CubeTensor, CubeRuntime}; use cubecl::prelude::*; use cubecl::{calculate_cube_count_elemwise, CubeDim}; @@ -44,12 +44,12 @@ fn select_assign_kernel( } } -pub(crate) fn select_assign( - tensor: JitTensor, +pub(crate) fn select_assign( + tensor: CubeTensor, dim: usize, - indices: JitTensor, - value: JitTensor, -) -> JitTensor { + indices: CubeTensor, + value: CubeTensor, +) -> CubeTensor { let ndims = tensor.shape.num_dims(); let tensor = match tensor.can_mut() { true => tensor, diff --git a/crates/burn-cubecl/src/kernel/index/slice.rs b/crates/burn-cubecl/src/kernel/index/slice.rs index bca8e00dd9..8c02b3f954 100644 --- a/crates/burn-cubecl/src/kernel/index/slice.rs +++ b/crates/burn-cubecl/src/kernel/index/slice.rs @@ -1,13 +1,13 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use burn_tensor::Shape; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use std::ops::Range; /// Slice a jit tensor with a set of ranges -pub fn slice( - tensor: JitTensor, +pub fn slice( + tensor: CubeTensor, indices: &[Range], -) -> JitTensor { +) -> CubeTensor { let mut dims = tensor.shape.dims.clone(); let mut offset_start = 0u64; let mut offset_end = 0u64; @@ -26,7 +26,7 @@ pub fn slice( if offset_start % memory_offset_alignment == 0u64 && offset_end % memory_offset_alignment == 0u64 { - JitTensor::new( + CubeTensor::new( tensor.client, tensor .handle @@ -69,11 +69,11 @@ fn slice_kernel( output[ABSOLUTE_POS] = input[offset_input]; } -pub(crate) fn slice_on_output( - tensor: JitTensor, - output: JitTensor, +pub(crate) fn slice_on_output( + tensor: CubeTensor, + output: CubeTensor, indices: &[Range], -) -> JitTensor { +) -> CubeTensor { let ndims = tensor.shape.num_dims(); let mut indices_sequence = SequenceArg::::new(); diff --git a/crates/burn-cubecl/src/kernel/index/slice_assign.rs b/crates/burn-cubecl/src/kernel/index/slice_assign.rs index ca3c9adf6e..1d9bc0f4ee 100644 --- a/crates/burn-cubecl/src/kernel/index/slice_assign.rs +++ 
b/crates/burn-cubecl/src/kernel/index/slice_assign.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, tensor::CubeTensor, CubeRuntime}; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use std::ops::Range; @@ -25,11 +25,11 @@ fn slice_assign_kernel( input[offset_input] = value[offset_value]; } -pub(crate) fn slice_assign( - tensor: JitTensor, +pub(crate) fn slice_assign( + tensor: CubeTensor, indices: &[Range], - value: JitTensor, -) -> JitTensor { + value: CubeTensor, +) -> CubeTensor { let tensor = match tensor.can_mut() { true => tensor, false => tensor.copy(), diff --git a/crates/burn-cubecl/src/kernel/interpolate/base.rs b/crates/burn-cubecl/src/kernel/interpolate/base.rs index c3d3a51b21..01db39e0a8 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/base.rs +++ b/crates/burn-cubecl/src/kernel/interpolate/base.rs @@ -1,6 +1,6 @@ use crate::{ - kernel::into_contiguous, ops::numeric::empty_device, tensor::JitTensor, FloatElement, - JitRuntime, + kernel::into_contiguous, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime, + FloatElement, }; use burn_tensor::{ ops::{InterpolateMode, InterpolateOptions}, @@ -15,11 +15,11 @@ use super::{ /// Interpolate operation /// /// Supports nearest, bilinear and bicubic modes -pub fn interpolate( - input: JitTensor, +pub fn interpolate( + input: CubeTensor, output_size: [usize; 2], options: InterpolateOptions, -) -> JitTensor { +) -> CubeTensor { let input = into_contiguous(input); let [batch_size, channels, _, _] = input.shape.dims(); let [out_height, out_width] = output_size; @@ -37,17 +37,17 @@ pub fn interpolate( /// Backward interpolate operation /// /// Note: only nearest mode is supported -pub fn interpolate_backward( - input: JitTensor, - out_grad: JitTensor, +pub fn interpolate_backward( + input: CubeTensor, + out_grad: CubeTensor, _output_size: [usize; 2], options: InterpolateOptions, -) -> JitTensor { +) -> CubeTensor { let out_grad = into_contiguous(out_grad); let output_shape = input.shape.clone(); let num_elems = input.shape.num_elements(); let buffer = input.client.empty(num_elems * core::mem::size_of::()); - let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( input.client.clone(), input.device.clone(), output_shape, diff --git a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs b/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs index 3f77ef1302..bbf3223060 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs @@ -1,6 +1,6 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; -use crate::{tensor::JitTensor, FloatElement, JitRuntime}; +use crate::{tensor::CubeTensor, CubeRuntime, FloatElement}; #[cube(launch)] fn interpolate_bicubic_kernel(input: &Tensor, output: &mut Tensor) { @@ -121,10 +121,10 @@ fn cubic_convolution_2(x: F, a: F) -> F { conv - F::new(4.0) * a } -pub(crate) fn interpolate_bicubic_launch( - input: JitTensor, - output: JitTensor, -) -> JitTensor { +pub(crate) fn interpolate_bicubic_launch( + input: CubeTensor, + output: CubeTensor, +) -> CubeTensor { let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); diff --git a/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs b/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs index f0cb95b536..b866c644a2 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs +++ 
b/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs @@ -1,6 +1,6 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; -use crate::{tensor::JitTensor, FloatElement, JitRuntime}; +use crate::{tensor::CubeTensor, CubeRuntime, FloatElement}; #[cube(launch)] fn interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor) { @@ -79,10 +79,10 @@ fn interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor< output[ABSOLUTE_POS] = p_a + p_b + p_c + p_d; } -pub(crate) fn interpolate_bilinear_launch( - input: JitTensor, - output: JitTensor, -) -> JitTensor { +pub(crate) fn interpolate_bilinear_launch( + input: CubeTensor, + output: CubeTensor, +) -> CubeTensor { let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); diff --git a/crates/burn-cubecl/src/kernel/interpolate/nearest.rs b/crates/burn-cubecl/src/kernel/interpolate/nearest.rs index 0e6ba32552..8bba333b97 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/nearest.rs +++ b/crates/burn-cubecl/src/kernel/interpolate/nearest.rs @@ -1,6 +1,6 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; -use crate::{tensor::JitTensor, FloatElement, JitRuntime}; +use crate::{tensor::CubeTensor, CubeRuntime, FloatElement}; #[cube(launch_unchecked)] fn interpolate_nearest_kernel(input: &Tensor, output: &mut Tensor) { @@ -31,10 +31,10 @@ fn interpolate_nearest_kernel(input: &Tensor, output: &mut Tensor( - input: JitTensor, - output: JitTensor, -) -> JitTensor { +pub(crate) fn interpolate_nearest_launch( + input: CubeTensor, + output: CubeTensor, +) -> CubeTensor { let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); diff --git a/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs b/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs index f0442ec92e..c9e0d9a97a 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs @@ -1,6 +1,6 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; -use crate::{tensor::JitTensor, FloatElement, JitRuntime}; +use crate::{tensor::CubeTensor, CubeRuntime, FloatElement}; #[cube(launch_unchecked)] fn interpolate_nearest_backward_kernel(grad: &Tensor, output: &mut Tensor) { @@ -55,10 +55,10 @@ fn end_index(input_index: u32, output_size: u32, input_size: u32) -> u Min::min(output_size, index) } -pub(crate) fn interpolate_nearest_backward_launch( - out_grad: JitTensor, - output: JitTensor, -) -> JitTensor { +pub(crate) fn interpolate_nearest_backward_launch( + out_grad: CubeTensor, + output: CubeTensor, +) -> CubeTensor { let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); diff --git a/crates/burn-cubecl/src/kernel/mask/base.rs b/crates/burn-cubecl/src/kernel/mask/base.rs index d37c6e05bb..44b190f3ec 100644 --- a/crates/burn-cubecl/src/kernel/mask/base.rs +++ b/crates/burn-cubecl/src/kernel/mask/base.rs @@ -1,12 +1,12 @@ use super::{mask_where::MaskWhereStrategy, MaskFillStrategy}; -use crate::{element::JitElement, tensor::JitTensor, BoolElement, JitRuntime}; +use crate::{element::CubeElement, tensor::CubeTensor, BoolElement, CubeRuntime}; /// Execute the mask fill kernel. 
-pub(crate) fn mask_fill_auto( - tensor: JitTensor, - mask: JitTensor, +pub(crate) fn mask_fill_auto( + tensor: CubeTensor, + mask: CubeTensor, value: E, -) -> JitTensor { +) -> CubeTensor { let strategy = if tensor.can_mut() { MaskFillStrategy::Inplace } else { @@ -17,11 +17,11 @@ pub(crate) fn mask_fill_auto( } /// Execute the mask where kernel. -pub(crate) fn mask_where_auto( - tensor: JitTensor, - mask: JitTensor, - value: JitTensor, -) -> JitTensor { +pub(crate) fn mask_where_auto( + tensor: CubeTensor, + mask: CubeTensor, + value: CubeTensor, +) -> CubeTensor { let strategy = if tensor.can_mut_broadcast(&value) { MaskWhereStrategy::InplaceLhs } else if value.can_mut_broadcast(&tensor) { diff --git a/crates/burn-cubecl/src/kernel/mask/mask_fill.rs b/crates/burn-cubecl/src/kernel/mask/mask_fill.rs index 95096c7994..4a3c1c2cdf 100644 --- a/crates/burn-cubecl/src/kernel/mask/mask_fill.rs +++ b/crates/burn-cubecl/src/kernel/mask/mask_fill.rs @@ -1,10 +1,10 @@ use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*}; use crate::{ - element::JitElement, + element::CubeElement, ops::{max_vectorization, numeric::empty_device}, - tensor::JitTensor, - BoolElement, JitRuntime, + tensor::CubeTensor, + BoolElement, CubeRuntime, }; #[cube(launch)] @@ -58,23 +58,23 @@ pub enum MaskFillStrategy { } /// Execute the mask fill kernel with the given strategy. -pub fn mask_fill( - input: JitTensor, - mask: JitTensor, +pub fn mask_fill( + input: CubeTensor, + mask: CubeTensor, value: E, strategy: MaskFillStrategy, -) -> JitTensor { +) -> CubeTensor { match strategy { MaskFillStrategy::Readonly => mask_fill_readonly::(input, mask, value), MaskFillStrategy::Inplace => mask_fill_inplace::(input, mask, value), } } -fn mask_fill_readonly( - input: JitTensor, - mask: JitTensor, +fn mask_fill_readonly( + input: CubeTensor, + mask: CubeTensor, value: EI, -) -> JitTensor { +) -> CubeTensor { let ndims = input.shape.num_dims(); let output = empty_device::( input.client.clone(), @@ -100,11 +100,11 @@ fn mask_fill_readonly( output } -fn mask_fill_inplace( - input: JitTensor, - mask: JitTensor, +fn mask_fill_inplace( + input: CubeTensor, + mask: CubeTensor, value: EI, -) -> JitTensor { +) -> CubeTensor { let ndims = input.shape.num_dims(); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); diff --git a/crates/burn-cubecl/src/kernel/mask/mask_where.rs b/crates/burn-cubecl/src/kernel/mask/mask_where.rs index 99384fde98..1c91b0e4ca 100644 --- a/crates/burn-cubecl/src/kernel/mask/mask_where.rs +++ b/crates/burn-cubecl/src/kernel/mask/mask_where.rs @@ -1,10 +1,10 @@ use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*}; use crate::{ - element::JitElement, + element::CubeElement, ops::{max_vectorization, numeric::empty_device}, - tensor::JitTensor, - BoolElement, JitRuntime, + tensor::CubeTensor, + BoolElement, CubeRuntime, }; #[cube(launch)] @@ -65,12 +65,12 @@ pub enum MaskWhereStrategy { } /// Execute the mask where kernel with the given strategy. 
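The mask kernels keep their strategy-selection logic; only the tensor and runtime types change, and `mask_where` (immediately below) follows the same pattern. A hedged sketch of how a crate-internal caller would drive the renamed fill kernel; the `crate::kernel::mask` path and the `<R, E, BT>` generic order are assumptions based on the imports above, not verbatim source:

```rust
// Sketch only: strategy selection mirrors `mask_fill_auto` shown above.
use crate::{
    kernel::mask::{mask_fill, MaskFillStrategy},
    tensor::CubeTensor,
    BoolElement, CubeElement, CubeRuntime,
};

fn fill_masked<R: CubeRuntime, E: CubeElement, BT: BoolElement>(
    tensor: CubeTensor<R>,
    mask: CubeTensor<R>,
    value: E,
) -> CubeTensor<R> {
    // Reuse the input buffer when the kernel is allowed to mutate it,
    // otherwise write the result into a fresh output tensor.
    let strategy = if tensor.can_mut() {
        MaskFillStrategy::Inplace
    } else {
        MaskFillStrategy::Readonly
    };
    mask_fill::<R, E, BT>(tensor, mask, value, strategy)
}
```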
-pub fn mask_where( - input: JitTensor, - mask: JitTensor, - value: JitTensor, +pub fn mask_where( + input: CubeTensor, + mask: CubeTensor, + value: CubeTensor, strategy: MaskWhereStrategy, -) -> JitTensor { +) -> CubeTensor { match strategy { MaskWhereStrategy::Readonly => mask_where_readonly::(input, mask, value), MaskWhereStrategy::InplaceLhs => mask_where_inplace::(input, mask, value, false), @@ -78,11 +78,11 @@ pub fn mask_where( } } -fn mask_where_readonly( - input: JitTensor, - mask: JitTensor, - value: JitTensor, -) -> JitTensor { +fn mask_where_readonly( + input: CubeTensor, + mask: CubeTensor, + value: CubeTensor, +) -> CubeTensor { let ndims = input.shape.num_dims(); let output = empty_device::( input.client.clone(), @@ -108,12 +108,12 @@ fn mask_where_readonly( output } -fn mask_where_inplace( - input: JitTensor, - mask: JitTensor, - value: JitTensor, +fn mask_where_inplace( + input: CubeTensor, + mask: CubeTensor, + value: CubeTensor, reverse: bool, -) -> JitTensor { +) -> CubeTensor { let ndims = input.shape.num_dims(); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); diff --git a/crates/burn-cubecl/src/kernel/matmul/base.rs b/crates/burn-cubecl/src/kernel/matmul/base.rs index 611f1e32d4..e4e7f4ed0c 100644 --- a/crates/burn-cubecl/src/kernel/matmul/base.rs +++ b/crates/burn-cubecl/src/kernel/matmul/base.rs @@ -1,7 +1,7 @@ use cubecl::linalg::matmul::kernels::MatmulLaunchError; use super::init_matmul_output; -use crate::{tensor::JitTensor, FloatElement, JitRuntime}; +use crate::{tensor::CubeTensor, CubeRuntime, FloatElement}; #[cfg(feature = "autotune")] use super::matmul_autotune; @@ -27,12 +27,12 @@ impl Default for MatmulStrategy { } /// Launch a matmul kernel using the given strategy. 
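The `matmul` entry point whose signature follows keeps the strategy-based dispatch; a hedged sketch of a crate-internal call, with the `<R, E>` generics and the `crate::kernel::matmul` path assumed from the imports above (the `MatmulLaunchError` import is quoted from the hunk itself):

```rust
// Sketch, not verbatim source: generics and module path are assumptions.
use crate::{
    kernel::matmul::{matmul, MatmulStrategy},
    tensor::CubeTensor,
    CubeRuntime, FloatElement,
};
use cubecl::linalg::matmul::kernels::MatmulLaunchError;

fn product<R: CubeRuntime, E: FloatElement>(
    lhs: CubeTensor<R>,
    rhs: CubeTensor<R>,
) -> Result<CubeTensor<R>, MatmulLaunchError> {
    // The default strategy presumably resolves to autotune when that feature is
    // enabled (see the `impl Default for MatmulStrategy` context above); passing
    // `None` lets the kernel allocate the output via `init_matmul_output`.
    matmul::<R, E>(lhs, rhs, None, MatmulStrategy::default())
}
```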
-pub fn matmul( - lhs: JitTensor, - rhs: JitTensor, - out: Option>, +pub fn matmul( + lhs: CubeTensor, + rhs: CubeTensor, + out: Option>, strategy: MatmulStrategy, -) -> Result, MatmulLaunchError> { +) -> Result, MatmulLaunchError> { match strategy { MatmulStrategy::Cube => { let out = out.unwrap_or_else(|| init_matmul_output::(&lhs, &rhs)); diff --git a/crates/burn-cubecl/src/kernel/matmul/tune/base.rs b/crates/burn-cubecl/src/kernel/matmul/tune/base.rs index dacd2693b9..69279ead63 100644 --- a/crates/burn-cubecl/src/kernel/matmul/tune/base.rs +++ b/crates/burn-cubecl/src/kernel/matmul/tune/base.rs @@ -8,19 +8,19 @@ use crate::{ element::FloatElement, kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, ops::numeric::empty_device, - tensor::JitTensor, - tune_key::JitAutotuneKey, - JitRuntime, JitTuneId, + tensor::CubeTensor, + tune_key::CubeAutotuneKey, + CubeRuntime, CubeTuneId, }; use super::key::create_key; -fn matmul_input_gen( - _key: &JitAutotuneKey, - lhs: &JitTensor, - rhs: &JitTensor, - out: &JitTensor, -) -> (JitTensor, JitTensor, JitTensor) { +fn matmul_input_gen( + _key: &CubeAutotuneKey, + lhs: &CubeTensor, + rhs: &CubeTensor, + out: &CubeTensor, +) -> (CubeTensor, CubeTensor, CubeTensor) { let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); let lhs = random_like_uniform(lhs, random_bounds.0, random_bounds.1); let rhs = random_like_uniform(rhs, random_bounds.0, random_bounds.1); @@ -31,16 +31,16 @@ fn matmul_input_gen( } /// Executes autotune on matmul operations -pub fn matmul_autotune( - lhs: JitTensor, - rhs: JitTensor, - out: Option>, -) -> JitTensor { +pub fn matmul_autotune( + lhs: CubeTensor, + rhs: CubeTensor, + out: Option>, +) -> CubeTensor { let output = out.unwrap_or_else(|| init_matmul_output::(&lhs, &rhs)); let client = lhs.client.clone(); - static TUNER: LocalTuner = local_tuner!(); + static TUNER: LocalTuner = local_tuner!(); let tunables = TunableSet::new(create_key::, matmul_input_gen::) .with_tunable(matmul_tiling2d::) @@ -48,7 +48,7 @@ pub fn matmul_autotune( .with_tunable(matmul_simple::); TUNER.execute( - &JitTuneId::new::(&lhs.device), + &CubeTuneId::new::(&lhs.device), &client, &tunables, (lhs, rhs, output.clone()), @@ -57,10 +57,10 @@ pub fn matmul_autotune( output } -fn matmul_accelerated( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, +fn matmul_accelerated( + lhs: CubeTensor, + rhs: CubeTensor, + out: CubeTensor, ) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Standard, @@ -72,10 +72,10 @@ fn matmul_accelerated( .map_err(|err| format!("{err:?}")) } -fn matmul_tiling2d( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, +fn matmul_tiling2d( + lhs: CubeTensor, + rhs: CubeTensor, + out: CubeTensor, ) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Tiling2D(Tiling2dConfig::default()), @@ -87,10 +87,10 @@ fn matmul_tiling2d( .map_err(|err| format!("{err:?}")) } -fn matmul_simple( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, +fn matmul_simple( + lhs: CubeTensor, + rhs: CubeTensor, + out: CubeTensor, ) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Simple, diff --git a/crates/burn-cubecl/src/kernel/matmul/tune/key.rs b/crates/burn-cubecl/src/kernel/matmul/tune/key.rs index 44cb079399..c75772f30a 100644 --- a/crates/burn-cubecl/src/kernel/matmul/tune/key.rs +++ b/crates/burn-cubecl/src/kernel/matmul/tune/key.rs @@ -1,4 +1,4 @@ -use crate::{tensor::JitTensor, FloatElement, JitAutotuneKey, JitRuntime}; +use 
crate::{tensor::CubeTensor, CubeAutotuneKey, CubeRuntime, FloatElement}; use burn_tensor::{DType, Shape}; use core::fmt::Debug; use cubecl::AutotuneKey; @@ -47,12 +47,12 @@ impl MatmulAutotuneKey { } } -pub(crate) fn create_key( - lhs: &JitTensor, - rhs: &JitTensor, - _out: &JitTensor, -) -> JitAutotuneKey { - JitAutotuneKey::Matmul(MatmulAutotuneKey::from_shape( +pub(crate) fn create_key( + lhs: &CubeTensor, + rhs: &CubeTensor, + _out: &CubeTensor, +) -> CubeAutotuneKey { + CubeAutotuneKey::Matmul(MatmulAutotuneKey::from_shape( &lhs.shape, &rhs.shape, E::dtype(), diff --git a/crates/burn-cubecl/src/kernel/matmul/utils.rs b/crates/burn-cubecl/src/kernel/matmul/utils.rs index fa65ce60d3..c0f2f13a39 100644 --- a/crates/burn-cubecl/src/kernel/matmul/utils.rs +++ b/crates/burn-cubecl/src/kernel/matmul/utils.rs @@ -1,15 +1,15 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use burn_tensor::Shape; /// Creates an empty output tensor with matmul output shape -pub fn init_matmul_output( - lhs: &JitTensor, - rhs: &JitTensor, -) -> JitTensor { +pub fn init_matmul_output( + lhs: &CubeTensor, + rhs: &CubeTensor, +) -> CubeTensor { empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out(lhs, rhs)) } -pub(crate) fn shape_out(lhs: &JitTensor, rhs: &JitTensor) -> Shape { +pub(crate) fn shape_out(lhs: &CubeTensor, rhs: &CubeTensor) -> Shape { let ndims = lhs.shape.num_dims(); let mut shape_out = vec![0; ndims]; lhs.shape diff --git a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs index 564c99ab68..7ed6b94c3d 100644 --- a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs +++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use burn_tensor::Shape; use cubecl::{calculate_cube_count_elemwise, prelude::*}; @@ -76,10 +76,10 @@ fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { } } -pub(crate) fn adaptive_avg_pool2d( - input: JitTensor, +pub(crate) fn adaptive_avg_pool2d( + input: CubeTensor, output_size: [usize; 2], -) -> JitTensor { +) -> CubeTensor { let [batch_size, channels, _, _] = input.shape.dims(); let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]); diff --git a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs index 1552389f0a..68540b74e3 100644 --- a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs +++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, tensor::CubeTensor, CubeRuntime}; use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch)] @@ -80,14 +80,14 @@ fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { } } -pub(crate) fn adaptive_avg_pool2d_backward( - x: JitTensor, - out_grad: JitTensor, -) -> JitTensor { +pub(crate) fn adaptive_avg_pool2d_backward( + x: CubeTensor, + out_grad: CubeTensor, +) -> CubeTensor { let output_shape = x.shape.clone(); let num_elems = output_shape.num_elements(); let output_buffer = 
x.client.empty(num_elems * core::mem::size_of::()); - let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( x.client.clone(), x.device.clone(), output_shape, diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs index 964dba67f4..682c4a8c55 100644 --- a/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs +++ b/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs @@ -1,7 +1,7 @@ use super::pool2d::{ pool2d_direct, Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, }; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use burn_tensor::{ops::conv::calculate_pool_output_size, Shape}; use cubecl::prelude::*; use cubecl::{calculate_cube_count_elemwise, prelude::ScalarArg, CubeDim}; @@ -65,13 +65,13 @@ impl Pool2dDirectStrategy for AvgPoolStrategy { } } -pub(crate) fn avg_pool2d( - x: JitTensor, +pub(crate) fn avg_pool2d( + x: CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, -) -> JitTensor { +) -> CubeTensor { let [batch_size, channels, _, _] = x.shape.dims(); let dilation = 1; diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs index d2a5a21d0a..995e20c267 100644 --- a/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs @@ -1,6 +1,6 @@ use crate::{ - element::JitElement, kernel::into_contiguous, ops::numeric::empty_device, tensor::JitTensor, - JitRuntime, + element::CubeElement, kernel::into_contiguous, ops::numeric::empty_device, tensor::CubeTensor, + CubeRuntime, }; use cubecl::{calculate_cube_count_elemwise, prelude::*}; @@ -108,14 +108,14 @@ fn loop_ranges( (oh_start, oh_end, ow_start, ow_end) } -pub(crate) fn avg_pool2d_backward( - x: JitTensor, - grad: JitTensor, +pub(crate) fn avg_pool2d_backward( + x: CubeTensor, + grad: CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, -) -> JitTensor { +) -> CubeTensor { let grad = into_contiguous(grad); let dilation = 1; diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs index 29923ff608..c978744142 100644 --- a/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs +++ b/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs @@ -1,7 +1,7 @@ use super::pool2d::{ pool2d_direct, Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, }; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use burn_tensor::{ops::conv::calculate_pool_output_size, Shape}; use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim}; @@ -86,13 +86,13 @@ impl Pool2dDirectStrategy for MaxPoolWithIndicesStrategy { } } -pub(crate) fn max_pool2d( - x: JitTensor, +pub(crate) fn max_pool2d( + x: CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], -) -> JitTensor { +) -> CubeTensor { let [batch_size, channels, _, _] = x.shape.dims(); let size_0 = calculate_pool_output_size( @@ -138,13 +138,13 @@ pub(crate) fn max_pool2d( output } -pub(crate) fn max_pool2d_with_indices( - x: JitTensor, +pub(crate) fn max_pool2d_with_indices( + x: 
CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], -) -> (JitTensor, JitTensor) { +) -> (CubeTensor, CubeTensor) { let [batch_size, channels, _, _] = x.shape.dims(); let size_0 = calculate_pool_output_size( diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs index 40259c4573..81f2b2c1b8 100644 --- a/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs @@ -1,6 +1,6 @@ use crate::{ - element::JitElement, kernel::into_contiguous, ops::numeric::empty_device, tensor::JitTensor, - IntElement, JitRuntime, + element::CubeElement, kernel::into_contiguous, ops::numeric::empty_device, tensor::CubeTensor, + CubeRuntime, IntElement, }; use cubecl::{calculate_cube_count_elemwise, prelude::*}; @@ -73,15 +73,15 @@ fn loop_ranges( (oh_start, oh_end, ow_start, ow_end) } -pub(crate) fn max_pool2d_with_indices_backward( - x: JitTensor, - grad: JitTensor, - indices: JitTensor, +pub(crate) fn max_pool2d_with_indices_backward( + x: CubeTensor, + grad: CubeTensor, + indices: CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], -) -> JitTensor { +) -> CubeTensor { let grad = into_contiguous(grad); let indices = into_contiguous(indices); diff --git a/crates/burn-cubecl/src/kernel/prng/base.rs b/crates/burn-cubecl/src/kernel/prng/base.rs index e4b35864cf..2a51b0a00c 100644 --- a/crates/burn-cubecl/src/kernel/prng/base.rs +++ b/crates/burn-cubecl/src/kernel/prng/base.rs @@ -1,6 +1,6 @@ use cubecl::prelude::*; -use crate::{ops::numeric::empty_device, tensor::JitTensor, JitElement, JitRuntime, SEED}; +use crate::{ops::numeric::empty_device, tensor::CubeTensor, CubeElement, CubeRuntime, SEED}; use burn_common::rand::get_seeded_rng; use burn_tensor::Shape; use rand::Rng; @@ -8,11 +8,11 @@ use rand::Rng; pub(crate) const N_VALUES_PER_THREAD: usize = 128; /// Pseudo-random generator -pub(crate) fn random, R: JitRuntime, E: JitElement>( +pub(crate) fn random, R: CubeRuntime, E: CubeElement>( shape: Shape, device: &R::Device, prng: P, -) -> JitTensor { +) -> CubeTensor { let client = R::client(device); let output = empty_device::(client.clone(), device.clone(), shape); let seeds = get_seeds(); @@ -61,14 +61,14 @@ pub(crate) fn get_seeds() -> [u32; 4] { seeds.try_into().unwrap() } -pub(crate) trait PrngArgs: Send + Sync + 'static { +pub(crate) trait PrngArgs: Send + Sync + 'static { type Args: LaunchArg; fn args<'a, R: Runtime>(self) -> ::RuntimeArg<'a, R>; } #[cube] -pub(crate) trait PrngRuntime: Send + Sync + 'static + PrngArgs { +pub(crate) trait PrngRuntime: Send + Sync + 'static + PrngArgs { #[allow(clippy::too_many_arguments)] fn inner_loop( args: Self::Args, @@ -84,7 +84,7 @@ pub(crate) trait PrngRuntime: Send + Sync + 'static + PrngArgs } #[cube(launch)] -fn prng_kernel, E: JitElement>( +fn prng_kernel, E: CubeElement>( output: &mut Tensor, seed_0: u32, seed_1: u32, diff --git a/crates/burn-cubecl/src/kernel/prng/bernoulli.rs b/crates/burn-cubecl/src/kernel/prng/bernoulli.rs index 6a54097d3f..3454ca3418 100644 --- a/crates/burn-cubecl/src/kernel/prng/bernoulli.rs +++ b/crates/burn-cubecl/src/kernel/prng/bernoulli.rs @@ -3,8 +3,8 @@ use cubecl::prelude::*; use crate::{ kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2}, - tensor::JitTensor, - JitElement, JitRuntime, + tensor::CubeTensor, + CubeElement, CubeRuntime, }; use super::{random, PrngArgs, 
PrngRuntime}; @@ -15,7 +15,7 @@ pub(crate) struct Bernoulli { } #[cube] -impl PrngRuntime for Bernoulli { +impl PrngRuntime for Bernoulli { fn inner_loop( args: Bernoulli, write_index_base: u32, @@ -46,7 +46,7 @@ impl PrngRuntime for Bernoulli { } } -impl PrngArgs for Bernoulli { +impl PrngArgs for Bernoulli { type Args = Self; fn args<'a, R: Runtime>(self) -> BernoulliLaunch<'a, E, R> { @@ -55,10 +55,10 @@ impl PrngArgs for Bernoulli { } /// Pseudo-random generator with bernoulli distribution -pub fn random_bernoulli( +pub fn random_bernoulli( shape: Shape, device: &R::Device, probability: E, -) -> JitTensor { +) -> CubeTensor { random(shape, device, Bernoulli { probability }) } diff --git a/crates/burn-cubecl/src/kernel/prng/normal.rs b/crates/burn-cubecl/src/kernel/prng/normal.rs index 1453ce9715..4c63564b3e 100644 --- a/crates/burn-cubecl/src/kernel/prng/normal.rs +++ b/crates/burn-cubecl/src/kernel/prng/normal.rs @@ -5,8 +5,8 @@ use burn_tensor::Shape; use crate::{ kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2}, - tensor::JitTensor, - JitElement, JitRuntime, + tensor::CubeTensor, + CubeElement, CubeRuntime, }; use super::{random, PrngArgs, PrngRuntime}; @@ -18,7 +18,7 @@ pub(crate) struct Normal { } #[cube] -impl PrngRuntime for Normal { +impl PrngRuntime for Normal { fn inner_loop( args: Normal, write_index_base: u32, @@ -73,7 +73,7 @@ impl PrngRuntime for Normal { } } -impl PrngArgs for Normal { +impl PrngArgs for Normal { type Args = Self; fn args<'a, R: Runtime>(self) -> NormalLaunch<'a, E, R> { @@ -82,11 +82,11 @@ impl PrngArgs for Normal { } /// Pseudo-random generator with uniform distribution -pub fn random_normal( +pub fn random_normal( shape: Shape, device: &R::Device, mean: E, std: E, -) -> JitTensor { +) -> CubeTensor { random(shape, device, Normal { mean, std }) } diff --git a/crates/burn-cubecl/src/kernel/prng/uniform.rs b/crates/burn-cubecl/src/kernel/prng/uniform.rs index 18e8975377..be3197f55a 100644 --- a/crates/burn-cubecl/src/kernel/prng/uniform.rs +++ b/crates/burn-cubecl/src/kernel/prng/uniform.rs @@ -3,8 +3,8 @@ use cubecl::prelude::*; use crate::{ kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2}, - tensor::JitTensor, - JitElement, JitRuntime, + tensor::CubeTensor, + CubeElement, CubeRuntime, }; use super::{random, PrngArgs, PrngRuntime}; @@ -16,7 +16,7 @@ pub(crate) struct Uniform { } #[cube] -impl PrngRuntime for Uniform { +impl PrngRuntime for Uniform { fn inner_loop( args: Uniform, write_index_base: u32, @@ -54,7 +54,7 @@ impl PrngRuntime for Uniform { } } -impl PrngArgs for Uniform { +impl PrngArgs for Uniform { type Args = Self; fn args<'a, R: Runtime>(self) -> UniformLaunch<'a, E, R> { @@ -66,12 +66,12 @@ impl PrngArgs for Uniform { } /// Pseudo-random generator with uniform distribution -pub fn random_uniform( +pub fn random_uniform( shape: Shape, device: &R::Device, lower_bound: E, upper_bound: E, -) -> JitTensor { +) -> CubeTensor { random( shape, device, @@ -83,11 +83,11 @@ pub fn random_uniform( } /// Pseudo-random generator for uniform distribution, based on /// another tensor. 
-pub fn random_like_uniform( - tensor: &JitTensor, +pub fn random_like_uniform( + tensor: &CubeTensor, lower_bound: E, upper_bound: E, -) -> JitTensor { +) -> CubeTensor { random_uniform( tensor.shape.clone(), &tensor.device, diff --git a/crates/burn-cubecl/src/kernel/quantization/dequantize.rs b/crates/burn-cubecl/src/kernel/quantization/dequantize.rs index 270e32f854..c318baa276 100644 --- a/crates/burn-cubecl/src/kernel/quantization/dequantize.rs +++ b/crates/burn-cubecl/src/kernel/quantization/dequantize.rs @@ -1,6 +1,6 @@ -use crate::tensor::JitTensor; +use crate::tensor::CubeTensor; use crate::FloatElement; -use crate::{JitElement, JitRuntime}; +use crate::{CubeElement, CubeRuntime}; use burn_tensor::quantization::{QuantizationScheme, QuantizationType}; use burn_tensor::DType; use cubecl::calculate_cube_count_elemwise; @@ -107,10 +107,10 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( } } -pub(crate) fn dequantize_per_tensor(tensor: JitTensor) -> JitTensor +pub(crate) fn dequantize_per_tensor(tensor: CubeTensor) -> CubeTensor where - R: JitRuntime, - F: JitElement, + R: CubeRuntime, + F: CubeElement, { // The actual number of elements is 1/4 (four int8 values packed in a single u32) // so we choose a line size to match a valid input binding size. @@ -124,7 +124,7 @@ where let client = tensor.client.clone(); let handle = client.empty(num_out_elems * core::mem::size_of::()); - let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( client.clone(), tensor.device.clone(), tensor.shape.clone(), @@ -165,9 +165,9 @@ where } /// Convert the tensor back to a higher precision data type. -pub fn dequantize(tensor: JitTensor) -> JitTensor +pub fn dequantize(tensor: CubeTensor) -> CubeTensor where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, { dequantize_per_tensor::(tensor) diff --git a/crates/burn-cubecl/src/kernel/quantization/quantize.rs b/crates/burn-cubecl/src/kernel/quantization/quantize.rs index 0a7b0ea553..49565e9ca0 100644 --- a/crates/burn-cubecl/src/kernel/quantization/quantize.rs +++ b/crates/burn-cubecl/src/kernel/quantization/quantize.rs @@ -1,6 +1,6 @@ -use crate::tensor::JitTensor; +use crate::tensor::CubeTensor; use crate::FloatElement; -use crate::{IntElement, JitElement, JitRuntime}; +use crate::{CubeElement, CubeRuntime, IntElement}; use burn_tensor::quantization::{QuantizationScheme, QuantizationType}; use cubecl::calculate_cube_count_elemwise; use cubecl::prelude::*; @@ -157,14 +157,14 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( } pub(crate) fn quantize_per_tensor( - tensor: JitTensor, - scale: JitTensor, - offset: Option>, + tensor: CubeTensor, + scale: CubeTensor, + offset: Option>, scheme: QuantizationScheme, -) -> JitTensor +) -> CubeTensor where - R: JitRuntime, - F: JitElement, + R: CubeRuntime, + F: CubeElement, I: IntElement, { let ndims = tensor.shape.num_dims(); @@ -183,7 +183,7 @@ where // Scale and offset qparams are also packed in the tensor dat let handle = client .empty(output_num_elems + core::mem::size_of::() + core::mem::size_of::()); - let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( client.clone(), tensor.device.clone(), tensor.shape.clone(), @@ -209,7 +209,7 @@ where } else { // Scale qparam is also packed in the tensor data let handle = client.empty(output_num_elems + core::mem::size_of::()); - let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( client.clone(), tensor.device.clone(), tensor.shape.clone(), @@ -237,13 +237,13 
@@ where /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. pub fn quantize( - tensor: JitTensor, + tensor: CubeTensor, scheme: &QuantizationScheme, - scale: JitTensor, - offset: Option>, -) -> JitTensor + scale: CubeTensor, + offset: Option>, +) -> CubeTensor where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, { diff --git a/crates/burn-cubecl/src/kernel/reduce/base.rs b/crates/burn-cubecl/src/kernel/reduce/base.rs index e54541d6a5..c7d24b26f0 100644 --- a/crates/burn-cubecl/src/kernel/reduce/base.rs +++ b/crates/burn-cubecl/src/kernel/reduce/base.rs @@ -1,6 +1,6 @@ #[cfg(feature = "autotune")] use super::{autotune_reduce, autotune_sum}; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use burn_tensor::Shape; pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum}; use cubecl::reduce::shared_sum; @@ -11,10 +11,10 @@ use cubecl::reduce::shared_sum; /// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction. /// /// Return an error if the `client` doesn't support atomic add for the type `E`. -pub fn sum( - tensor: JitTensor, +pub fn sum( + tensor: CubeTensor, cube_count: SumStrategy, -) -> Result, cubecl::reduce::ReduceError> { +) -> Result, cubecl::reduce::ReduceError> { let client = tensor.client.clone(); let device = tensor.device.clone(); @@ -22,7 +22,7 @@ pub fn sum( SumStrategy::OneShot(cube_count) => { let handle = client.create(E::as_bytes(&[E::from_int(0)])); let output = - JitTensor::new_contiguous(client.clone(), device, [1].into(), handle, E::dtype()); + CubeTensor::new_contiguous(client.clone(), device, [1].into(), handle, E::dtype()); shared_sum::( &client, tensor.as_handle_ref(), @@ -66,10 +66,10 @@ impl Default for SumStrategy { /// /// If there is no error, the output is a tensor with decreasing strides /// where the shape of reduced dim is set to 1 but all shape are similar to the input. -pub fn reduce( - mut tensor: JitTensor, +pub fn reduce( + mut tensor: CubeTensor, strategy: ReduceStrategy, -) -> Result, cubecl::reduce::ReduceError> { +) -> Result, cubecl::reduce::ReduceError> { // In practice, it looks like starting by the axis with the smallest shape // and going in increasing order lead to the fastest calculation. let sorted_axis = argsort(&tensor.shape.dims); @@ -95,11 +95,16 @@ fn argsort(shape: &[usize]) -> Vec { /// /// If there is no error, the output is a tensor with decreasing strides /// where the shape of reduced dim is set to 1 but all shape are similar to the input. 
-pub fn reduce_dim( - input: JitTensor, +pub fn reduce_dim< + Run: CubeRuntime, + In: CubeElement, + Out: CubeElement, + Rd: cubecl::reduce::Reduce, +>( + input: CubeTensor, dim: usize, strategy: ReduceStrategy, -) -> Result, cubecl::reduce::ReduceError> { +) -> Result, cubecl::reduce::ReduceError> { let client = input.client.clone(); let output = init_reduce_output::(&input, dim).ok_or( cubecl::reduce::ReduceError::InvalidAxis { @@ -133,10 +138,10 @@ pub fn reduce_dim( - input: &JitTensor, +pub fn init_reduce_output( + input: &CubeTensor, dim: usize, -) -> Option> { +) -> Option> { (dim < input.shape.num_dims()).then(|| { let mut shape_out = input.shape.clone(); shape_out.dims[dim] = 1; diff --git a/crates/burn-cubecl/src/kernel/reduce/tune.rs b/crates/burn-cubecl/src/kernel/reduce/tune.rs index 8babe00c24..1f57e8ebdf 100644 --- a/crates/burn-cubecl/src/kernel/reduce/tune.rs +++ b/crates/burn-cubecl/src/kernel/reduce/tune.rs @@ -9,25 +9,25 @@ use cubecl::{ use serde::{Deserialize, Serialize}; use crate::{ - kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor, - JitAutotuneKey, JitElement, JitRuntime, JitTuneId, + kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::CubeTensor, + CubeAutotuneKey, CubeElement, CubeRuntime, CubeTuneId, }; /// Executes autotune on reduce operations. pub fn autotune_reduce< - Run: JitRuntime, - In: JitElement, - Out: JitElement, + Run: CubeRuntime, + In: CubeElement, + Out: CubeElement, Rd: cubecl::reduce::Reduce, >( client: &ComputeClient, - input: JitTensor, - output: JitTensor, + input: CubeTensor, + output: CubeTensor, dim: usize, ) { use reduce_ops::*; - static TUNER: LocalTuner = local_tuner!(); + static TUNER: LocalTuner = local_tuner!(); let tunables = TunableSet::new(create_key::, reduce_input_gen::) .with_tunable(reduce::) @@ -36,7 +36,7 @@ pub fn autotune_reduce< .with_tunable(reduce_shared_plane::); TUNER.execute( - &JitTuneId::new::(&input.device), + &CubeTuneId::new::(&input.device), client, &tunables, (input, output, dim), @@ -56,7 +56,7 @@ pub struct ReduceAutotuneKey { } impl ReduceAutotuneKey { - pub(crate) fn generate(input: &JitTensor, axis: usize) -> Self { + pub(crate) fn generate(input: &CubeTensor, axis: usize) -> Self { let rank = input.shape.num_dims(); if axis > rank { @@ -83,12 +83,12 @@ impl ReduceAutotuneKey { } } -pub(crate) fn create_key( - input: &JitTensor, - _output: &JitTensor, +pub(crate) fn create_key( + input: &CubeTensor, + _output: &CubeTensor, dim: &usize, -) -> JitAutotuneKey { - JitAutotuneKey::Reduce(ReduceAutotuneKey::generate(input, *dim)) +) -> CubeAutotuneKey { + CubeAutotuneKey::Reduce(ReduceAutotuneKey::generate(input, *dim)) } mod reduce_ops { @@ -96,12 +96,12 @@ mod reduce_ops { use super::*; - pub(crate) fn reduce_input_gen( - _key: &JitAutotuneKey, - input: &JitTensor, - output: &JitTensor, + pub(crate) fn reduce_input_gen( + _key: &CubeAutotuneKey, + input: &CubeTensor, + output: &CubeTensor, dim: &usize, - ) -> (JitTensor, JitTensor, usize) { + ) -> (CubeTensor, CubeTensor, usize) { let random_bounds: (In, In) = ((-10.0_f32).elem::(), (10.0_f32).elem::()); let input = random_like_uniform(input, random_bounds.0, random_bounds.1); @@ -115,13 +115,13 @@ mod reduce_ops { } pub(crate) fn reduce< - Run: JitRuntime, - In: JitElement, - Out: JitElement, + Run: CubeRuntime, + In: CubeElement, + Out: CubeElement, Rd: cubecl::reduce::Reduce, >( - input: JitTensor, - output: JitTensor, + input: CubeTensor, + output: CubeTensor, axis: usize, ) -> Result<(), String> { 
cubecl::reduce::reduce::( @@ -138,13 +138,13 @@ mod reduce_ops { } pub(crate) fn reduce_shared< - Run: JitRuntime, - In: JitElement, - Out: JitElement, + Run: CubeRuntime, + In: CubeElement, + Out: CubeElement, Rd: cubecl::reduce::Reduce, >( - input: JitTensor, - output: JitTensor, + input: CubeTensor, + output: CubeTensor, axis: usize, ) -> Result<(), String> { cubecl::reduce::reduce::( @@ -161,13 +161,13 @@ mod reduce_ops { } pub(crate) fn reduce_plane< - Run: JitRuntime, - In: JitElement, - Out: JitElement, + Run: CubeRuntime, + In: CubeElement, + Out: CubeElement, Rd: cubecl::reduce::Reduce, >( - input: JitTensor, - output: JitTensor, + input: CubeTensor, + output: CubeTensor, axis: usize, ) -> Result<(), String> { cubecl::reduce::reduce::( @@ -184,13 +184,13 @@ mod reduce_ops { } pub(crate) fn reduce_shared_plane< - Run: JitRuntime, - In: JitElement, - Out: JitElement, + Run: CubeRuntime, + In: CubeElement, + Out: CubeElement, Rd: cubecl::reduce::Reduce, >( - input: JitTensor, - output: JitTensor, + input: CubeTensor, + output: CubeTensor, axis: usize, ) -> Result<(), String> { cubecl::reduce::reduce::( @@ -209,13 +209,13 @@ mod reduce_ops { /// Executes autotune on reduce operations. #[cfg(feature = "autotune")] -pub fn autotune_sum( +pub fn autotune_sum( client: &ComputeClient, - input: JitTensor, -) -> JitTensor { + input: CubeTensor, +) -> CubeTensor { use sum_ops::*; - static TUNER: LocalTuner = local_tuner!(); + static TUNER: LocalTuner = local_tuner!(); let tunables = TunableSet::new(create_key_sum::, sum_input_gen::) .with_tunable(sum_one_shot::) @@ -228,15 +228,15 @@ pub fn autotune_sum( .with_tunable(sum_chained::); TUNER.execute( - &JitTuneId::new::(&input.device), + &CubeTuneId::new::(&input.device), client, &tunables, input, ) } -pub(crate) fn create_key_sum(input: &JitTensor) -> JitAutotuneKey { - JitAutotuneKey::Sum(SumAutotuneKey::generate(input)) +pub(crate) fn create_key_sum(input: &CubeTensor) -> CubeAutotuneKey { + CubeAutotuneKey::Sum(SumAutotuneKey::generate(input)) } #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] @@ -248,7 +248,7 @@ pub struct SumAutotuneKey { } impl SumAutotuneKey { - pub(crate) fn generate(input: &JitTensor) -> Self { + pub(crate) fn generate(input: &CubeTensor) -> Self { let dtype = input.dtype; let length = input.shape.num_elements(); Self { dtype, length } @@ -261,21 +261,21 @@ mod sum_ops { use super::*; - pub(crate) fn sum_input_gen( - _key: &JitAutotuneKey, - input: &JitTensor, - ) -> JitTensor { + pub(crate) fn sum_input_gen( + _key: &CubeAutotuneKey, + input: &CubeTensor, + ) -> CubeTensor { let random_bounds: (E, E) = ((-10.0_f32).elem::(), (10.0_f32).elem::()); random_like_uniform(input, random_bounds.0, random_bounds.1) } - pub(crate) fn sum_one_shot( - input: JitTensor, - ) -> Result, String> { + pub(crate) fn sum_one_shot( + input: CubeTensor, + ) -> Result, String> { let client = input.client.clone(); let device = input.device.clone(); let handle = client.create(E::as_bytes(&[E::from_int(0)])); - let output = JitTensor::new_contiguous(client, device, [1].into(), handle, E::dtype()); + let output = CubeTensor::new_contiguous(client, device, [1].into(), handle, E::dtype()); cubecl::reduce::shared_sum::( &input.client, @@ -288,9 +288,9 @@ mod sum_ops { } #[cfg(feature = "autotune")] - pub(crate) fn sum_chained( - input: JitTensor, - ) -> Result, String> { + pub(crate) fn sum_chained( + input: CubeTensor, + ) -> Result, String> { crate::kernel::reduce::reduce::( input, 
crate::kernel::reduce::ReduceStrategy::Autotune, diff --git a/crates/burn-cubecl/src/kernel/unary_float.rs b/crates/burn-cubecl/src/kernel/unary_float.rs index 4664d3c0b3..22f790e0e6 100644 --- a/crates/burn-cubecl/src/kernel/unary_float.rs +++ b/crates/burn-cubecl/src/kernel/unary_float.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, tensor_line_size_parallel, @@ -46,13 +46,13 @@ pub(crate) fn unary_float( } } -pub(crate) fn launch_unary_float(tensor: JitTensor, args: Args) -> JitTensor +pub(crate) fn launch_unary_float(tensor: CubeTensor, args: Args) -> CubeTensor where // Magic fix for lifetime, the closure is supposed to capture everything required to create the // argument. for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, - R: JitRuntime, - E: JitElement + Float, + R: CubeRuntime, + E: CubeElement + Float, O: FloatUnaryOpFamily, { let ndims = tensor.shape.num_dims(); @@ -113,9 +113,9 @@ pub(crate) mod unary_basic { use super::*; - pub(crate) fn launch(tensor: JitTensor, args: Args) -> JitTensor + pub(crate) fn launch(tensor: CubeTensor, args: Args) -> CubeTensor where - R: JitRuntime, + R: CubeRuntime, for<'a> Args: FnOnce(&'a ()) -> &'a BasicFloatUnaryKind, { execute_with_dtype!( diff --git a/crates/burn-cubecl/src/kernel/unary_int.rs b/crates/burn-cubecl/src/kernel/unary_int.rs index 17bced52d1..d9b46bc5cf 100644 --- a/crates/burn-cubecl/src/kernel/unary_int.rs +++ b/crates/burn-cubecl/src/kernel/unary_int.rs @@ -1,4 +1,4 @@ -use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use crate::{ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime, IntElement}; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, tensor_line_size_parallel, @@ -46,10 +46,10 @@ pub(crate) fn unary_int( } } -pub(crate) fn launch_unary_int(tensor: JitTensor, args: Args) -> JitTensor +pub(crate) fn launch_unary_int(tensor: CubeTensor, args: Args) -> CubeTensor where for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, - R: JitRuntime, + R: CubeRuntime, E: IntElement + Int, O: IntUnaryOpFamily, { @@ -107,9 +107,9 @@ pub(crate) mod unary_basic_int { use super::*; - pub(crate) fn launch(tensor: JitTensor, args: Args) -> JitTensor + pub(crate) fn launch(tensor: CubeTensor, args: Args) -> CubeTensor where - R: JitRuntime, + R: CubeRuntime, for<'a> Args: FnOnce(&'a ()) -> &'a BasicIntUnaryKind, I: IntElement, { diff --git a/crates/burn-cubecl/src/kernel/unary_numeric.rs b/crates/burn-cubecl/src/kernel/unary_numeric.rs index aaeadbb685..0aac28e111 100644 --- a/crates/burn-cubecl/src/kernel/unary_numeric.rs +++ b/crates/burn-cubecl/src/kernel/unary_numeric.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor, CubeRuntime}; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, tensor_line_size_parallel, @@ -46,13 +46,16 @@ pub(crate) fn unary_numeric( } } -pub(crate) fn launch_unary_numeric(tensor: JitTensor, args: Args) -> JitTensor +pub(crate) fn launch_unary_numeric( + tensor: CubeTensor, + args: Args, +) -> CubeTensor where // Magic fix for lifetime, the closure is 
supposed to capture everything required to create the // argument. for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, - R: JitRuntime, - E: JitElement + Numeric, + R: CubeRuntime, + E: CubeElement + Numeric, O: NumericUnaryOpFamily, { let ndims = tensor.shape.num_dims(); diff --git a/crates/burn-cubecl/src/lib.rs b/crates/burn-cubecl/src/lib.rs index ae15fb945f..17ae9983b3 100644 --- a/crates/burn-cubecl/src/lib.rs +++ b/crates/burn-cubecl/src/lib.rs @@ -20,7 +20,7 @@ pub mod element; use burn_tensor::backend::{DeviceId, DeviceOps}; use cubecl::{compute::CubeTask, Feature, Runtime}; -pub use element::{BoolElement, FloatElement, IntElement, JitElement}; +pub use element::{BoolElement, CubeElement, FloatElement, IntElement}; mod backend; @@ -30,7 +30,7 @@ pub use backend::*; pub use cubecl; mod tune_key; -pub use tune_key::JitAutotuneKey; +pub use tune_key::CubeAutotuneKey; #[cfg(any(feature = "fusion", test))] /// Module for interacting with fusion @@ -44,11 +44,11 @@ pub mod template; pub mod tests; /// Just-in-Time runtime extending the [cube runtime](Runtime). -pub trait JitRuntime: Runtime { +pub trait CubeRuntime: Runtime { /// The device that should also implement [burn_tensor::backend::DeviceOps]. - type JitDevice: burn_tensor::backend::DeviceOps; - /// The cube server with the [JitAutotuneKey]. - type JitServer: cubecl::server::ComputeServer< + type CubeDevice: burn_tensor::backend::DeviceOps; + /// The cube server with the [CubeAutotuneKey]. + type CubeServer: cubecl::server::ComputeServer< Kernel = Box>, Feature = Feature, >; @@ -56,14 +56,14 @@ pub trait JitRuntime: Runtime(device: &R::Device) -> Self { + pub fn new(device: &R::Device) -> Self { Self { device: DeviceOps::id(device), name: R::name(), @@ -71,7 +71,7 @@ impl JitTuneId { } } -impl core::fmt::Display for JitTuneId { +impl core::fmt::Display for CubeTuneId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( "device-{}-{}-{}", diff --git a/crates/burn-cubecl/src/ops/activation_ops.rs b/crates/burn-cubecl/src/ops/activation_ops.rs index eecd6849c8..87f1ec82aa 100644 --- a/crates/burn-cubecl/src/ops/activation_ops.rs +++ b/crates/burn-cubecl/src/ops/activation_ops.rs @@ -1,9 +1,9 @@ -use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{element::BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement}; use burn_tensor::ops::ActivationOps; -impl ActivationOps for JitBackend +impl ActivationOps for CubeBackend where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, diff --git a/crates/burn-cubecl/src/ops/base.rs b/crates/burn-cubecl/src/ops/base.rs index def5a4408f..f350ab450d 100644 --- a/crates/burn-cubecl/src/ops/base.rs +++ b/crates/burn-cubecl/src/ops/base.rs @@ -1,16 +1,16 @@ -use crate::{element::JitElement, kernel, tensor::JitTensor, BoolElement, JitRuntime}; +use crate::{element::CubeElement, kernel, tensor::CubeTensor, BoolElement, CubeRuntime}; use burn_tensor::{Shape, TensorData}; use cubecl::tensor_vectorization_factor; -pub(crate) fn from_data(data: TensorData, device: &R::Device) -> JitTensor { +pub(crate) fn from_data(data: TensorData, device: &R::Device) -> CubeTensor { let shape: Shape = (&data.shape).into(); let client = R::client(device); let buffer = client.create(data.as_bytes()); - JitTensor::new_contiguous(client, device.clone(), shape, buffer, data.dtype) + CubeTensor::new_contiguous(client, device.clone(), shape, buffer, data.dtype) } -pub(crate) async fn 
into_data(tensor: JitTensor) -> TensorData { +pub(crate) async fn into_data(tensor: CubeTensor) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; @@ -18,9 +18,9 @@ pub(crate) async fn into_data(tensor: JitTensor TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape) } -/// Read data from a `JitTensor` synchronously +/// Read data from a `CubeTensor` synchronously #[allow(unused, reason = "useful for debugging kernels")] -pub fn into_data_sync(tensor: JitTensor) -> TensorData { +pub fn into_data_sync(tensor: CubeTensor) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one(tensor.handle.binding()); @@ -28,8 +28,8 @@ pub fn into_data_sync(tensor: JitTensor) -> Ten TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape) } -pub(crate) async fn bool_into_data( - tensor: JitTensor, +pub(crate) async fn bool_into_data( + tensor: CubeTensor, ) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; @@ -43,7 +43,10 @@ pub(crate) async fn bool_into_data( ) } -pub(crate) fn to_device(tensor: JitTensor, device: &R::Device) -> JitTensor { +pub(crate) fn to_device( + tensor: CubeTensor, + device: &R::Device, +) -> CubeTensor { if &tensor.device == device { return tensor; } @@ -52,21 +55,21 @@ pub(crate) fn to_device(tensor: JitTensor, device: &R::Device) tensor.to_client(client, device.clone()) } -pub(crate) fn empty( +pub(crate) fn empty( shape: Shape, device: &R::Device, -) -> JitTensor { +) -> CubeTensor { let client = R::client(device); let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); - JitTensor::new_contiguous(client, device.clone(), shape, buffer, E::dtype()) + CubeTensor::new_contiguous(client, device.clone(), shape, buffer, E::dtype()) } -pub(crate) fn swap_dims( - mut tensor: JitTensor, +pub(crate) fn swap_dims( + mut tensor: CubeTensor, dim1: usize, dim2: usize, -) -> JitTensor { +) -> CubeTensor { tensor.strides.swap(dim1, dim2); tensor.shape.dims.swap(dim1, dim2); @@ -74,7 +77,7 @@ pub(crate) fn swap_dims( } /// Permute a tensor's dimensions -pub fn permute(mut tensor: JitTensor, axes: &[usize]) -> JitTensor { +pub fn permute(mut tensor: CubeTensor, axes: &[usize]) -> CubeTensor { // remap strides tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect(); @@ -83,7 +86,7 @@ pub fn permute(mut tensor: JitTensor, axes: &[usize]) -> JitTe tensor } -pub(crate) fn expand(tensor: JitTensor, target_shape: Shape) -> JitTensor { +pub(crate) fn expand(tensor: CubeTensor, target_shape: Shape) -> CubeTensor { let ndims_in = tensor.shape.num_dims(); let ndims_out = target_shape.num_dims(); @@ -123,7 +126,7 @@ pub(crate) fn expand(tensor: JitTensor, target_shape: Shape) - } } - JitTensor { + CubeTensor { client: tensor.client, device: tensor.device, shape: target_shape, @@ -134,11 +137,11 @@ pub(crate) fn expand(tensor: JitTensor, target_shape: Shape) - } /// Reshape a jit tensor to a new shape -pub fn reshape(tensor: JitTensor, shape: Shape) -> JitTensor { +pub fn reshape(tensor: CubeTensor, shape: Shape) -> CubeTensor { // TODO: Not force standard layout all the time (improve performance). 
let tensor = kernel::into_contiguous(tensor); - JitTensor::new_contiguous( + CubeTensor::new_contiguous( tensor.client, tensor.device, shape, @@ -147,7 +150,7 @@ pub fn reshape(tensor: JitTensor, shape: Shape) -> JitTensor(tensor: &JitTensor) -> u8 { +pub(crate) fn max_vectorization(tensor: &CubeTensor) -> u8 { tensor_vectorization_factor( R::supported_line_sizes(), &tensor.shape.dims, diff --git a/crates/burn-cubecl/src/ops/bool_ops.rs b/crates/burn-cubecl/src/ops/bool_ops.rs index 4342be6421..a29ffefc8e 100644 --- a/crates/burn-cubecl/src/ops/bool_ops.rs +++ b/crates/burn-cubecl/src/ops/bool_ops.rs @@ -1,13 +1,13 @@ -use crate::{element::BoolElement, kernel, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{element::BoolElement, kernel, CubeBackend, CubeRuntime, FloatElement, IntElement}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor}; use burn_tensor::{ops::BoolTensorOps, Shape, TensorData}; use std::ops::Range; use super::{expand, permute}; -impl BoolTensorOps for JitBackend +impl BoolTensorOps for CubeBackend where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, diff --git a/crates/burn-cubecl/src/ops/float_ops.rs b/crates/burn-cubecl/src/ops/float_ops.rs index 1f2060f104..43139ce033 100644 --- a/crates/burn-cubecl/src/ops/float_ops.rs +++ b/crates/burn-cubecl/src/ops/float_ops.rs @@ -8,8 +8,8 @@ use crate::{ element::BoolElement, kernel::matmul::{matmul, MatmulStrategy}, }; -use crate::{execute_with_dtype, JitBackend}; -use crate::{FloatElement, IntElement, JitRuntime}; +use crate::{execute_with_dtype, CubeBackend}; +use crate::{CubeRuntime, FloatElement, IntElement}; use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; use burn_tensor::{ops::FloatTensorOps, Distribution, Shape, TensorData}; use burn_tensor::{DType, ElementConversion, FloatDType}; @@ -17,9 +17,9 @@ use cubecl::prelude::*; use half::{bf16, f16}; use std::ops::Range; -impl FloatTensorOps for JitBackend +impl FloatTensorOps for CubeBackend where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, diff --git a/crates/burn-cubecl/src/ops/int_ops.rs b/crates/burn-cubecl/src/ops/int_ops.rs index f90459b21b..acf4bc8f3d 100644 --- a/crates/burn-cubecl/src/ops/int_ops.rs +++ b/crates/burn-cubecl/src/ops/int_ops.rs @@ -9,7 +9,7 @@ use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, }; -use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{kernel, CubeBackend, CubeRuntime, FloatElement, IntElement}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use burn_tensor::DType; use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData}; @@ -17,9 +17,9 @@ use cubecl::frontend::Numeric; use cubecl::prelude::*; use std::ops::Range; -impl IntTensorOps for JitBackend +impl IntTensorOps for CubeBackend where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, diff --git a/crates/burn-cubecl/src/ops/module_ops.rs b/crates/burn-cubecl/src/ops/module_ops.rs index c7f7b18b32..cc2302c37f 100644 --- a/crates/burn-cubecl/src/ops/module_ops.rs +++ b/crates/burn-cubecl/src/ops/module_ops.rs @@ -4,7 +4,7 @@ use crate::{ self, conv::{Conv2dStrategy, ConvTranspose2dStrategy}, }, - FloatElement, IntElement, JitBackend, JitRuntime, + CubeBackend, CubeRuntime, FloatElement, IntElement, }; use burn_tensor::ops::{ ConvOptions, ConvTransposeOptions, 
DeformConv2dBackward, DeformConvOptions, InterpolateOptions, @@ -12,9 +12,9 @@ use burn_tensor::ops::{ }; use burn_tensor::ops::{FloatTensor, IntTensor}; -impl ModuleOps for JitBackend +impl ModuleOps for CubeBackend where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, diff --git a/crates/burn-cubecl/src/ops/numeric.rs b/crates/burn-cubecl/src/ops/numeric.rs index 432276ccb6..6cea803cc7 100644 --- a/crates/burn-cubecl/src/ops/numeric.rs +++ b/crates/burn-cubecl/src/ops/numeric.rs @@ -2,31 +2,31 @@ use crate::kernel::{ launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int, AddOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, }; -use crate::{element::JitElement, tensor::JitTensor}; -use crate::{FloatElement, IntElement, JitRuntime}; +use crate::{element::CubeElement, tensor::CubeTensor}; +use crate::{CubeRuntime, FloatElement, IntElement}; use burn_tensor::{ElementConversion, Shape}; use cubecl::client::ComputeClient; use cubecl::tensor_vectorization_factor; use cubecl::{calculate_cube_count_elemwise, prelude::*}; /// Create a tensor filled with `value` -pub fn full( +pub fn full( shape: Shape, device: &R::Device, value: E, -) -> JitTensor { +) -> CubeTensor { let client = R::client(device); full_device::(client, shape, device.clone(), value) } /// Create a tensor filled with `value` -pub fn full_device( +pub fn full_device( client: ComputeClient, shape: Shape, device: R::Device, value: E, -) -> JitTensor { +) -> CubeTensor { let ndims = shape.num_dims(); let empty = empty_device::(client, device, shape); @@ -59,141 +59,168 @@ pub fn full_device( } /// Create a tensor filled with zeros -pub fn zeros(shape: Shape, device: &R::Device) -> JitTensor { +pub fn zeros(shape: Shape, device: &R::Device) -> CubeTensor { let client = R::client(device); zeros_device::(client, device.clone(), shape) } /// Create a tensor filled with zeros -pub fn zeros_device( +pub fn zeros_device( client: ComputeClient, device: R::Device, shape: Shape, -) -> JitTensor { +) -> CubeTensor { full_device::(client, shape, device, 0.elem()) } /// Create a tensor filled with ones -pub fn ones(shape: Shape, device: &R::Device) -> JitTensor { +pub fn ones(shape: Shape, device: &R::Device) -> CubeTensor { let client = R::client(device); ones_device::(client, device.clone(), shape) } /// Create a tensor filled with ones -pub fn ones_device( +pub fn ones_device( client: ComputeClient, device: R::Device, shape: Shape, -) -> JitTensor { +) -> CubeTensor { full_device::(client, shape, device, 1.elem()) } /// Create a tensor with uninitialized memory -pub fn empty_device( +pub fn empty_device( client: ComputeClient, device: R::Device, shape: Shape, -) -> JitTensor { +) -> CubeTensor { let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); - JitTensor::new_contiguous(client, device, shape, buffer, E::dtype()) + CubeTensor::new_contiguous(client, device, shape, buffer, E::dtype()) } /// Add two tensors -pub fn add(lhs: JitTensor, rhs: JitTensor) -> JitTensor { +pub fn add( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_binop::(lhs, rhs) } /// Add a tensor and a scalar -pub fn add_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn add_scalar(lhs: CubeTensor, rhs: E) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Subtract two tensors -pub fn sub(lhs: JitTensor, rhs: JitTensor) -> JitTensor { +pub fn sub( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_binop::(lhs, rhs) } /// Subtract 
a tensor and a scalar -pub fn sub_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn sub_scalar(lhs: CubeTensor, rhs: E) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Multiply two tensors -pub fn mul(lhs: JitTensor, rhs: JitTensor) -> JitTensor { +pub fn mul( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_binop::(lhs, rhs) } /// Multiply a tensor and a scalar -pub fn mul_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn mul_scalar(lhs: CubeTensor, rhs: E) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Divide two tensors -pub fn div(lhs: JitTensor, rhs: JitTensor) -> JitTensor { +pub fn div( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_binop::(lhs, rhs) } /// Divide a tensor by a scalar -pub fn div_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn div_scalar(lhs: CubeTensor, rhs: E) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Calculate remainder of two tensors -pub fn remainder( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn remainder( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_binop::(lhs, rhs) } /// Calculate the remainder of a tensor with a scalar -pub fn remainder_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn remainder_scalar( + lhs: CubeTensor, + rhs: E, +) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Calculate the power of two tensors -pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { +pub fn pow( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_binop::>(lhs, rhs) } /// Bitwise and two tensors -pub fn bitwise_and( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn bitwise_and( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_binop_int::(lhs, rhs) } /// Bitwise and with a scalar -pub fn bitwise_and_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn bitwise_and_scalar( + lhs: CubeTensor, + rhs: E, +) -> CubeTensor { launch_scalar_binop_int::(lhs, rhs) } /// Bitwise or two tensors -pub fn bitwise_or( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn bitwise_or( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_binop_int::(lhs, rhs) } /// Bitwise or with a scalar -pub fn bitwise_or_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn bitwise_or_scalar( + lhs: CubeTensor, + rhs: E, +) -> CubeTensor { launch_scalar_binop_int::(lhs, rhs) } /// Bitwise xor two tensors -pub fn bitwise_xor( - lhs: JitTensor, - rhs: JitTensor, -) -> JitTensor { +pub fn bitwise_xor( + lhs: CubeTensor, + rhs: CubeTensor, +) -> CubeTensor { launch_binop_int::(lhs, rhs) } /// Bitwise xor with a scalar -pub fn bitwise_xor_scalar(lhs: JitTensor, rhs: E) -> JitTensor { +pub fn bitwise_xor_scalar( + lhs: CubeTensor, + rhs: E, +) -> CubeTensor { launch_scalar_binop_int::(lhs, rhs) } diff --git a/crates/burn-cubecl/src/ops/qtensor.rs b/crates/burn-cubecl/src/ops/qtensor.rs index d6bfe7499d..b61a46195e 100644 --- a/crates/burn-cubecl/src/ops/qtensor.rs +++ b/crates/burn-cubecl/src/ops/qtensor.rs @@ -7,21 +7,21 @@ use burn_tensor::{ }; use crate::{ - element::BoolElement, kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, - JitRuntime, + element::BoolElement, kernel, tensor::CubeTensor, CubeBackend, CubeRuntime, FloatElement, + IntElement, }; /// Create a quantized tensor with packed values (u32). 
-fn new_qtensor>( +fn new_qtensor>( data: &[u8], shape: S, scheme: QuantizationScheme, device: &R::Device, -) -> JitTensor { +) -> CubeTensor { let client = R::client(device); let buffer = client.create(data); - JitTensor::new_contiguous( + CubeTensor::new_contiguous( client, device.clone(), shape.into(), @@ -30,9 +30,9 @@ fn new_qtensor>( ) } -impl QTensorOps for JitBackend +impl QTensorOps for CubeBackend where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, diff --git a/crates/burn-cubecl/src/ops/transaction.rs b/crates/burn-cubecl/src/ops/transaction.rs index 3a4c712756..b8001051be 100644 --- a/crates/burn-cubecl/src/ops/transaction.rs +++ b/crates/burn-cubecl/src/ops/transaction.rs @@ -3,11 +3,11 @@ use burn_tensor::{ DType, TensorData, }; -use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{element::BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement}; -impl TransactionOps for JitBackend +impl TransactionOps for CubeBackend where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, diff --git a/crates/burn-cubecl/src/template/base.rs b/crates/burn-cubecl/src/template/base.rs index 5dd40bf14b..514b84cb00 100644 --- a/crates/burn-cubecl/src/template/base.rs +++ b/crates/burn-cubecl/src/template/base.rs @@ -1,4 +1,4 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use crate::{element::CubeElement, tensor::CubeTensor, CubeRuntime}; use burn_common::ExecutionMode; use cubecl::{prelude::*, Compiler, KernelId}; @@ -76,7 +76,7 @@ macro_rules! kernel_source { /// | (D + 1)..(2 * D + 1) | rhs strides | /// | (2 * D + 1)..(3 * D + 1) | lhs shape | /// | (3 * D + 1)..(4 * D + 1) | rhs shape | -pub fn build_info(tensors: &[&JitTensor]) -> Vec { +pub fn build_info(tensors: &[&CubeTensor]) -> Vec { let ndims = tensors[0].shape.num_dims(); let mut info: Vec = vec![0; tensors.len() * 2 * ndims + 1]; info[0] = ndims as u32; diff --git a/crates/burn-cubecl/src/tensor/base.rs b/crates/burn-cubecl/src/tensor/base.rs index 7b72073c06..4835b1ffe2 100644 --- a/crates/burn-cubecl/src/tensor/base.rs +++ b/crates/burn-cubecl/src/tensor/base.rs @@ -1,6 +1,6 @@ -use crate::element::JitElement; +use crate::element::CubeElement; use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}; -use crate::JitRuntime; +use crate::CubeRuntime; use burn_tensor::quantization::QTensorPrimitive; use burn_tensor::{DType, Shape, TensorMetadata}; use cubecl::client::ComputeClient; @@ -12,8 +12,8 @@ use std::marker::PhantomData; /// The basic tensor primitive struct. #[derive(new)] -pub struct JitTensor { - /// Compute client for the [runtime](JitRuntime). +pub struct CubeTensor { + /// Compute client for the [runtime](CubeRuntime). pub client: ComputeClient, /// The buffer where the data are stored. 
pub handle: Handle, @@ -27,19 +27,19 @@ pub struct JitTensor { pub dtype: DType, } -impl From> for TensorHandle { - fn from(val: JitTensor) -> Self { +impl From> for TensorHandle { + fn from(val: CubeTensor) -> Self { TensorHandle::new(val.shape.dims.to_vec(), val.strides.to_vec(), val.handle) } } -impl core::fmt::Debug for JitTensor +impl core::fmt::Debug for CubeTensor where - R: JitRuntime, + R: CubeRuntime, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( - "JitTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}", + "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}", self.shape, self.device, self.strides, @@ -49,9 +49,9 @@ where } } -impl Clone for JitTensor +impl Clone for CubeTensor where - R: JitRuntime, + R: CubeRuntime, { fn clone(&self) -> Self { Self { @@ -65,7 +65,7 @@ where } } -impl TensorMetadata for JitTensor { +impl TensorMetadata for CubeTensor { fn dtype(&self) -> DType { self.dtype } @@ -75,7 +75,7 @@ impl TensorMetadata for JitTensor { } } -impl QTensorPrimitive for JitTensor { +impl QTensorPrimitive for CubeTensor { fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme { if let DType::QFloat(scheme) = &self.dtype { scheme @@ -191,9 +191,9 @@ macro_rules! execute_with_dtype { }}; } -impl JitTensor +impl CubeTensor where - R: JitRuntime, + R: CubeRuntime, { /// Create a new tensor with a contiguous memory layout. pub fn new_contiguous( @@ -270,7 +270,7 @@ where } /// Return the reference to a tensor argument. - pub fn as_tensor_arg<'a, E: JitElement>(&'a self, vectorisation: u8) -> TensorArg<'a, R> { + pub fn as_tensor_arg<'a, E: CubeElement>(&'a self, vectorisation: u8) -> TensorArg<'a, R> { let handle: TensorHandleRef<'a, R> = self.as_handle_ref(); unsafe { @@ -284,7 +284,7 @@ where } /// Return the reference to an array argument. - pub fn as_array_arg(&self, vectorisation: u8) -> ArrayArg<'_, R> { + pub fn as_array_arg(&self, vectorisation: u8) -> ArrayArg<'_, R> { unsafe { ArrayArg::from_raw_parts::( &self.handle, diff --git a/crates/burn-cubecl/src/tests/mod.rs b/crates/burn-cubecl/src/tests/mod.rs index 022a31babe..c59602a17d 100644 --- a/crates/burn-cubecl/src/tests/mod.rs +++ b/crates/burn-cubecl/src/tests/mod.rs @@ -42,7 +42,7 @@ macro_rules! testgen_all { $crate::testgen_all!([Float], [Int], [Bool]); }; ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { - mod jit { + mod cube { burn_cubecl::testgen_jit!([$($float),*], [$($int),*], [$($bool),*]); mod kernel { @@ -84,7 +84,7 @@ macro_rules! testgen_all { burn_cubecl::testgen_quantization!(); } } - mod jit_fusion { + mod cube_fusion { burn_cubecl::testgen_jit_fusion!([$($float),*], [$($int),*], [$($bool),*]); } }; @@ -100,8 +100,8 @@ macro_rules! testgen_jit { pub use super::*; use burn_cubecl::tests::{burn_autodiff, burn_ndarray, burn_tensor, serial_test}; - pub type TestBackend = JitBackend; - pub type TestBackend2 = JitBackend; + pub type TestBackend = CubeBackend; + pub type TestBackend2 = CubeBackend; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; @@ -140,8 +140,8 @@ macro_rules! 
testgen_jit_fusion { use super::*; use burn_cubecl::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor}; - pub type TestBackend = burn_fusion::Fusion>; - pub type TestBackend2 = burn_fusion::Fusion>; + pub type TestBackend = burn_fusion::Fusion>; + pub type TestBackend2 = burn_fusion::Fusion>; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; diff --git a/crates/burn-cubecl/src/tune_key.rs b/crates/burn-cubecl/src/tune_key.rs index 9a86a85483..958d2269f0 100644 --- a/crates/burn-cubecl/src/tune_key.rs +++ b/crates/burn-cubecl/src/tune_key.rs @@ -9,7 +9,7 @@ use std::fmt::Display; #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)] /// Key for all autotune-enabled operations -pub enum JitAutotuneKey { +pub enum CubeAutotuneKey { /// Key for matmul operation Matmul(MatmulAutotuneKey), /// Key for reduce dim operations @@ -22,16 +22,16 @@ pub enum JitAutotuneKey { ConvTranspose2d(ConvTranspose2dAutotuneKey), } -impl Display for JitAutotuneKey { +impl Display for CubeAutotuneKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - JitAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), - JitAutotuneKey::Reduce(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), - JitAutotuneKey::Sum(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), - JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), - JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), + CubeAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), + CubeAutotuneKey::Reduce(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + CubeAutotuneKey::Sum(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + CubeAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), + CubeAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), } } } -impl AutotuneKey for JitAutotuneKey {} +impl AutotuneKey for CubeAutotuneKey {} diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs index d434865bb6..bce9de990e 100644 --- a/crates/burn-cuda/src/lib.rs +++ b/crates/burn-cuda/src/lib.rs @@ -2,19 +2,19 @@ extern crate alloc; -use burn_cubecl::JitBackend; +use burn_cubecl::CubeBackend; pub use cubecl::cuda::CudaDevice; use cubecl::cuda::CudaRuntime; #[cfg(not(feature = "fusion"))] -pub type Cuda = JitBackend; +pub type Cuda = CubeBackend; #[cfg(feature = "fusion")] -pub type Cuda = burn_fusion::Fusion>; +pub type Cuda = burn_fusion::Fusion>; #[cfg(test)] mod tests { - use burn_cubecl::JitBackend; + use burn_cubecl::CubeBackend; pub type TestRuntime = cubecl::cuda::CudaRuntime; pub use half::f16; diff --git a/crates/burn-hip/src/lib.rs b/crates/burn-hip/src/lib.rs index 8d11255b7c..530279db7d 100644 --- a/crates/burn-hip/src/lib.rs +++ b/crates/burn-hip/src/lib.rs @@ -2,7 +2,7 @@ extern crate alloc; #[cfg(target_os = "linux")] -use burn_cubecl::JitBackend; +use burn_cubecl::CubeBackend; #[cfg(target_os = "linux")] pub use cubecl::hip::HipDevice; @@ -12,18 +12,18 @@ use cubecl::hip::HipRuntime; #[cfg(target_os = "linux")] #[cfg(not(feature = "fusion"))] -pub type Hip = JitBackend; +pub type Hip = CubeBackend; #[cfg(target_os = "linux")] #[cfg(feature = "fusion")] -pub type Hip = burn_fusion::Fusion>; +pub type Hip = burn_fusion::Fusion>; // TODO: Hang the computer when AMD isn't available. 
// // #[cfg(target_os = "linux")] // #[cfg(test)] // mod tests { -// use burn_cubecl::JitBackend; +// use burn_cubecl::CubeBackend; // // pub type TestRuntime = cubecl::hip::HipRuntime; // pub use half::f16; diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index c382065a32..51965b3f6b 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -17,18 +17,18 @@ version.workspace = true [features] candle = ["burn-candle"] -default = ["ndarray", "jit-backend", "fusion"] +default = ["ndarray", "cubecl-backend", "fusion"] export-tests = ["burn-tensor-testgen"] fusion = ["burn-fusion", "burn-cuda/fusion", "burn-wgpu/fusion"] -jit-backend = ["cubecl", "burn-cubecl"] +cubecl-backend = ["cubecl", "burn-cubecl"] ndarray = ["burn-ndarray"] tch = ["burn-tch"] # Test features test-cpu = ["export-tests"] -test-cuda = ["jit-backend", "export-tests"] -test-vulkan = ["burn-wgpu/vulkan", "jit-backend", "export-tests"] -test-wgpu = ["jit-backend", "export-tests"] +test-cuda = ["cubecl-backend", "export-tests"] +test-vulkan = ["burn-wgpu/vulkan", "cubecl-backend", "export-tests"] +test-wgpu = ["cubecl-backend", "export-tests"] [dependencies] burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true } diff --git a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs similarity index 97% rename from crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs rename to crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs index 6d089c58ec..252c40568b 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs +++ b/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs @@ -4,14 +4,14 @@ //! 
DASIP, 2018 use crate::{ - backends::jit::connected_components::stats_from_opts, ConnectedStatsOptions, + backends::cube::connected_components::stats_from_opts, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, }; use burn_cubecl::{ kernel, ops::{into_data_sync, numeric::zeros_device}, - tensor::JitTensor, - BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, + tensor::CubeTensor, + BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement, }; use burn_tensor::{ops::IntTensorOps, Shape}; use cubecl::{prelude::*, Feature}; @@ -464,14 +464,14 @@ fn compact_stats( } #[allow(clippy::type_complexity)] -pub fn hardware_accelerated( - img: JitTensor, +pub fn hardware_accelerated( + img: CubeTensor, stats_opt: ConnectedStatsOptions, connectivity: Connectivity, ) -> Result< ( - JitTensor, - ConnectedStatsPrimitive>, + CubeTensor, + ConnectedStatsPrimitive>, ), String, > { @@ -566,7 +566,7 @@ pub fn hardware_accelerated::int_max(stats.max_label); + let max_label = CubeBackend::::int_max(stats.max_label); let max_label = into_data_sync::(max_label).convert::(); let max_label = max_label.as_slice::().unwrap()[0] as usize; let sliced = kernel::slice::( diff --git a/crates/burn-vision/src/backends/jit/connected_components/mod.rs b/crates/burn-vision/src/backends/cube/connected_components/mod.rs similarity index 87% rename from crates/burn-vision/src/backends/jit/connected_components/mod.rs rename to crates/burn-vision/src/backends/cube/connected_components/mod.rs index 95b0d8faec..152c74d0aa 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/mod.rs +++ b/crates/burn-vision/src/backends/cube/connected_components/mod.rs @@ -6,8 +6,8 @@ mod prefix_sum; use burn_cubecl::{ ops::numeric::{full_device, zeros_device}, - tensor::JitTensor, - BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, + tensor::CubeTensor, + BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement, }; use burn_tensor::Shape; pub use hardware_accelerated::*; @@ -15,11 +15,11 @@ pub use hardware_accelerated::*; use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive}; pub(crate) fn stats_from_opts( - l: JitTensor, + l: CubeTensor, opts: ConnectedStatsOptions, -) -> ConnectedStatsPrimitive> +) -> ConnectedStatsPrimitive> where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, @@ -30,7 +30,7 @@ where let max = I::max_value(); let max = || full_device::(l.client.clone(), shape.clone(), l.device.clone(), max); let dummy = || { - JitTensor::new_contiguous( + CubeTensor::new_contiguous( l.client.clone(), l.device.clone(), shape.clone(), diff --git a/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs b/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs similarity index 98% rename from crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs rename to crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs index ae81fd9ad0..f35923481b 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs +++ b/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs @@ -6,8 +6,8 @@ use burn_cubecl::{ numeric::{empty_device, zeros_device}, reshape, }, - tensor::JitTensor, - IntElement, JitRuntime, + tensor::CubeTensor, + CubeRuntime, IntElement, }; const CUBE_SIZE: u32 = 256; @@ -214,7 +214,7 @@ fn count_trailing_zeros(num: u32) -> u32 { } /// Compute the prefix sum of a tensor -pub fn prefix_sum(input: JitTensor) -> JitTensor { +pub fn 
prefix_sum(input: CubeTensor) -> CubeTensor { let client = input.client.clone(); let device = input.device.clone(); let num_elems = input.shape.num_elements() as u32; diff --git a/crates/burn-vision/src/backends/jit/mod.rs b/crates/burn-vision/src/backends/cube/mod.rs similarity index 100% rename from crates/burn-vision/src/backends/jit/mod.rs rename to crates/burn-vision/src/backends/cube/mod.rs diff --git a/crates/burn-vision/src/backends/jit/ops.rs b/crates/burn-vision/src/backends/cube/ops.rs similarity index 97% rename from crates/burn-vision/src/backends/jit/ops.rs rename to crates/burn-vision/src/backends/cube/ops.rs index e741d5dfd0..50eda8205a 100644 --- a/crates/burn-vision/src/backends/jit/ops.rs +++ b/crates/burn-vision/src/backends/cube/ops.rs @@ -1,11 +1,11 @@ use crate::{ backends::cpu, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, VisionOps, }; +use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement}; #[cfg(feature = "fusion")] use burn_fusion::{client::FusionClient, stream::Operation, Fusion, FusionBackend, FusionRuntime}; #[cfg(feature = "fusion")] use burn_ir::{CustomOpIr, HandleContainer, OperationIr}; -use burn_cubecl::{BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::{ ops::{BoolTensor, IntTensor}, Element, @@ -13,9 +13,9 @@ use burn_tensor::{ use super::connected_components::hardware_accelerated; -impl VisionOps for JitBackend +impl VisionOps for CubeBackend where - R: JitRuntime, + R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, diff --git a/crates/burn-vision/src/backends/mod.rs b/crates/burn-vision/src/backends/mod.rs index 6886bb4907..46b180864c 100644 --- a/crates/burn-vision/src/backends/mod.rs +++ b/crates/burn-vision/src/backends/mod.rs @@ -1,3 +1,3 @@ pub(crate) mod cpu; -#[cfg(feature = "jit-backend")] -mod jit; +#[cfg(feature = "cubecl-backend")] +mod cube; diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 8b3fc5d4d0..d653d62670 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -9,7 +9,7 @@ pub use burn_cubecl::{ template::{build_info, KernelSource, SourceKernel, SourceTemplate}, }; -pub use burn_cubecl::{tensor::JitTensor, JitBackend}; +pub use burn_cubecl::{tensor::CubeTensor, CubeBackend}; pub use burn_cubecl::{BoolElement, FloatElement, IntElement}; pub use cubecl::flex32; pub use cubecl::CubeDim; @@ -61,7 +61,7 @@ pub use cubecl::wgpu::WgslCompiler; /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. pub type Wgpu = - burn_fusion::Fusion>; + burn_fusion::Fusion>; #[cfg(not(feature = "fusion"))] /// Tensor backend that uses the wgpu crate for executing GPU compute shaders. @@ -95,7 +95,7 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. -pub type Wgpu = JitBackend; +pub type Wgpu = CubeBackend; #[cfg(feature = "vulkan")] /// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V. 
@@ -107,7 +107,7 @@ pub type WebGpu = Wgpu; #[cfg(test)] mod tests { - use burn_cubecl::JitBackend; + use burn_cubecl::CubeBackend; #[cfg(feature = "vulkan")] pub use half::f16; diff --git a/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs b/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs index de6bfcc7d4..a417e32e59 100644 --- a/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs +++ b/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs @@ -71,7 +71,7 @@ fn autodiff(device: &B::Device) { } fn main() { - type MyBackend = burn::backend::wgpu::JitBackend; + type MyBackend = burn::backend::wgpu::CubeBackend; type MyAutodiffBackend = burn::backend::Autodiff; let device = Default::default(); inference::(&device); diff --git a/examples/custom-cubecl-kernel/src/backward.rs b/examples/custom-cubecl-kernel/src/backward.rs index 708fc184c7..4cf769a729 100644 --- a/examples/custom-cubecl-kernel/src/backward.rs +++ b/examples/custom-cubecl-kernel/src/backward.rs @@ -10,10 +10,10 @@ use burn::{ }, tensor::{Shape, TensorMetadata}, }; -use burn_cubecl::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; +use burn_cubecl::{element::BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement}; -impl AutodiffBackend - for Autodiff> +impl AutodiffBackend + for Autodiff> { } diff --git a/examples/custom-cubecl-kernel/src/forward.rs b/examples/custom-cubecl-kernel/src/forward.rs index 5156d44fba..23b8ad157b 100644 --- a/examples/custom-cubecl-kernel/src/forward.rs +++ b/examples/custom-cubecl-kernel/src/forward.rs @@ -3,14 +3,14 @@ use crate::{kernel::fused_matmul_add_relu_kernel, FloatTensor}; use super::Backend; use burn::tensor::Shape; use burn_cubecl::{ - element::BoolElement, kernel::into_contiguous, tensor::JitTensor, FloatElement, IntElement, - JitBackend, JitRuntime, + element::BoolElement, kernel::into_contiguous, tensor::CubeTensor, CubeBackend, CubeRuntime, + FloatElement, IntElement, }; use cubecl::{CubeCount, CubeDim}; -/// Implement our custom backend trait for the generic `JitBackend`. -impl Backend - for JitBackend +/// Implement our custom backend trait for the generic `CubeBackend`. +impl Backend + for CubeBackend { fn fused_matmul_add_relu( lhs: FloatTensor, @@ -50,7 +50,7 @@ impl Backend .empty(shape_out.num_elements() * core::mem::size_of::()); // Create the output tensor primitive. 
- let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( lhs.client.clone(), lhs.device.clone(), shape_out, diff --git a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs index 0c2201080e..168267a33c 100644 --- a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs +++ b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs @@ -71,7 +71,7 @@ fn autodiff(device: &B::Device) { } fn main() { - type MyBackend = burn::backend::wgpu::JitBackend; + type MyBackend = burn::backend::wgpu::CubeBackend; type MyAutodiffBackend = burn::backend::Autodiff; let device = Default::default(); inference::(&device); diff --git a/examples/custom-wgpu-kernel/src/backward.rs b/examples/custom-wgpu-kernel/src/backward.rs index eb374d6c10..f7d4d21907 100644 --- a/examples/custom-wgpu-kernel/src/backward.rs +++ b/examples/custom-wgpu-kernel/src/backward.rs @@ -9,13 +9,13 @@ use burn::{ ops::{broadcast_shape, Backward, Ops, OpsKind}, Autodiff, NodeID, }, - wgpu::{BoolElement, FloatElement, IntElement, JitBackend, WgpuRuntime}, + wgpu::{BoolElement, CubeBackend, FloatElement, IntElement, WgpuRuntime}, }, tensor::{Shape, TensorMetadata}, }; impl AutodiffBackend - for Autodiff> + for Autodiff> { } diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index e257e13bf0..77e6ba0369 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -3,8 +3,8 @@ use crate::FloatTensor; use super::Backend; use burn::{ backend::wgpu::{ - build_info, into_contiguous, kernel_source, BoolElement, FloatElement, IntElement, - JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, + build_info, into_contiguous, kernel_source, BoolElement, CubeBackend, CubeTensor, + FloatElement, IntElement, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, }, tensor::Shape, }; @@ -42,7 +42,7 @@ impl KernelSource for FusedMatmulAddRelu { /// Implement our custom backend trait for the existing backend `WgpuBackend`. impl Backend - for JitBackend + for CubeBackend { fn fused_matmul_add_relu( lhs: FloatTensor, @@ -82,7 +82,7 @@ impl Backend .empty(shape_out.num_elements() * core::mem::size_of::()); // Create the output tensor primitive. - let output = JitTensor::new_contiguous( + let output = CubeTensor::new_contiguous( lhs.client.clone(), lhs.device.clone(), shape_out, diff --git a/examples/modern-lstm/README.md b/examples/modern-lstm/README.md index 832851a1f0..d26893a83e 100644 --- a/examples/modern-lstm/README.md +++ b/examples/modern-lstm/README.md @@ -21,7 +21,7 @@ the project's specific needs. ```sh # Cuda backend -cargo run --example lstm-train --release --features cuda-jit +cargo run --example lstm-train --release --features cuda # Wgpu backend cargo run --example lstm-train --release --features wgpu @@ -42,5 +42,5 @@ cargo run --example lstm-train --release --features ndarray-blas-netlib ### Inference ```sh -cargo run --example lstm-infer --release --features cuda-jit +cargo run --example lstm-infer --release --features cuda ```
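
Below is a minimal sketch of how downstream code might adapt to the rename, assuming the re-exports shown above (`CubeBackend`, `CubeRuntime`, `CubeTensor`) and the `wgpu` runtime with the usual `f32`/`i32`/`u32` element parameters; the `passthrough` helper is hypothetical and only illustrates the type substitutions.

```rust
// Hypothetical downstream snippet updated for the burn-jit -> burn-cubecl rename.
// Assumes burn-cubecl exposes CubeBackend, CubeRuntime and tensor::CubeTensor as in
// this patch, and that cubecl's WgpuRuntime implements CubeRuntime.
use burn_cubecl::{tensor::CubeTensor, CubeBackend, CubeRuntime};
use cubecl::wgpu::WgpuRuntime;

// Old: burn_jit::JitBackend<WgpuRuntime, f32, i32, u32>
#[allow(dead_code)]
type Backend = CubeBackend<WgpuRuntime, f32, i32, u32>;

// Old: fn passthrough<R: JitRuntime>(t: JitTensor<R>) -> JitTensor<R>
fn passthrough<R: CubeRuntime>(t: CubeTensor<R>) -> CubeTensor<R> {
    // No-op: returns the tensor unchanged, showing the renamed tensor type in a signature.
    t
}
```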