Skip to content

Commit

Permalink
bump cubecl version with dummy implementations (#2814)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxtremblay authored Feb 17, 2025
1 parent 37822fd commit 136eeb6
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 36 deletions.
32 changes: 16 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7f07d3969bc7d69c6ae2f87bd806dd4f18267741" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7f07d3969bc7d69c6ae2f87bd806dd4f18267741" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
26 changes: 26 additions & 0 deletions crates/burn-cubecl/src/fusion/matmul/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,32 @@ impl MatmulArgs for FusedMatmulArgs {
)
}

#[allow(unreachable_code)]
fn read_window_lhs<EG: Numeric>(
_state: &Self::State<EG>,
_start: u32,
_end: u32,
) -> Slice<Line<EG>> {
comptime!(todo!());
// TODO This is a dummy return value to satisfy the type checker
// before working on an implementation.
// Remove the allow annotation after implementing this function.
SharedMemory::new_lined(0, 0_u32).to_slice()
}

#[allow(unreachable_code)]
fn read_window_rhs<EG: Numeric>(
_state: &Self::State<EG>,
_start: u32,
_end: u32,
) -> Slice<Line<EG>> {
comptime!(todo!());
// TODO This is a dummy return value to satisfy the type checker
// before working on an implementation.
// Remove the allow annotation after implementing this function.
SharedMemory::new_lined(0, 0_u32).to_slice()
}

fn write_out<EG: Numeric>(state: &mut Self::State<EG>, coordinate: u32, value: Line<EG>) {
let mut values = Registry::<Arg, Line<EG>>::new();
let mut args = comptime![Sequence::<Arg>::new()];
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-cubecl/src/fusion/matmul/optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use cubecl::linalg::matmul::components;
use cubecl::linalg::matmul::components::tile::accelerated::Accelerated;
use cubecl::linalg::matmul::components::MatmulProblem;
use cubecl::linalg::matmul::kernels::matmul::{
MatmulSelector, PipelinedSelector, SpecializedSelector, StandardSelector,
MatmulSelector, SimplePipelinedSelector, SimpleSelector, SpecializedSelector,
};
use cubecl::linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError};
use cubecl::linalg::tensor::{matrix_layout, MatrixLayout};
Expand Down Expand Up @@ -361,7 +361,7 @@ impl FusedMatmul {

match self.selector {
FusedMatmulSelector::Standard => {
match matmul_launch_kernel::<R, EG, StandardSelector<Accelerated>>(
match matmul_launch_kernel::<R, EG, SimpleSelector<Accelerated>>(
client,
FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out),
outputs,
Expand All @@ -373,7 +373,7 @@ impl FusedMatmul {
}
}
FusedMatmulSelector::Pipelined => {
match matmul_launch_kernel::<R, EG, PipelinedSelector<Accelerated>>(
match matmul_launch_kernel::<R, EG, SimplePipelinedSelector<Accelerated>>(
client,
FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out),
outputs,
Expand Down
25 changes: 13 additions & 12 deletions crates/burn-cubecl/src/kernel/conv/conv2d/gemm/homogeneous/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use cubecl::{
components::{
global::{
self,
full_load::{self, CyclicLoading, RhsLoader},
output_loader::Unloader,
single_stage::{self, loader::RhsLoader, CyclicLoading},
AccumulatorLoader, GlobalConfig, InputLoader,
},
stage::{
Expand Down Expand Up @@ -69,7 +69,7 @@ where
>,
{
type LhsLoader = SimpleIm2colLoader<CS, Self::Config>;
type Config = HomogeneousConfig<full_load::Config<SMM::Config>>;
type Config = HomogeneousConfig<single_stage::Config<SMM::Config>>;
type RhsLoader = RhsLoader<CS::EG, CS::ES, SMM::Config, CyclicLoading>;
type AccumulatorLoader = BiasLoader<CS, SMM::Config>;

Expand Down Expand Up @@ -187,7 +187,7 @@ impl<SMM> ConvolutionConfigFactory for ImplicitGemmConvolutionFamily<SMM>
where
SMM: StageMatmulFamily,
{
type Config = config::HomogeneousConfig<full_load::Config<SMM::Config>>;
type Config = config::HomogeneousConfig<single_stage::Config<SMM::Config>>;
type Input = SMM::Input;

fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
Expand All @@ -212,7 +212,7 @@ where
let size = SMM::stage_shape(&smm_config);

config::HomogeneousConfig::new(
full_load::Config::new(
single_stage::Config::new(
smm_config,
// TODO: Find the correct condition to avoid check bounds.
true,
Expand All @@ -224,6 +224,7 @@ where
problem.rhs_line_size as u32,
problem.out_line_size as u32,
size.k,
global::LoadMode::Coalesced,
),
(problem.out_shape_y as u32, problem.out_shape_x as u32),
problem.kernel_size,
Expand Down Expand Up @@ -360,21 +361,21 @@ pub mod config {
self.matmul.tiling_order(ident)
}

fn check_m_bounds(&self) -> bool {
self.matmul.check_m_bounds()
fn check_row_bounds(&self, ident: Ident) -> bool {
self.matmul.check_row_bounds(ident)
}

fn check_n_bounds(&self) -> bool {
self.matmul.check_n_bounds()
}

fn check_k_bounds(&self) -> bool {
self.matmul.check_k_bounds()
fn check_col_bounds(&self, ident: Ident) -> bool {
self.matmul.check_col_bounds(ident)
}

fn transpose_load(&self, ident: Ident) -> bool {
self.matmul.transpose_load(ident)
}

fn load_mode(&self) -> global::LoadMode {
self.matmul.load_mode()
}
}

impl<M: GlobalConfig> gemm::ConvGemmConfig for HomogeneousConfig<M> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use cubecl::{
},
prelude::*,
};
use cubecl::prelude::pipeline::Pipeline;
use std::marker::PhantomData;

use crate::kernel::conv::{precision::ConvPrecision, reader::im2col::Im2colReader, ConvGemmConfig};
Expand Down Expand Up @@ -39,6 +40,11 @@ impl<CS: ConvPrecision, G: ConvGemmConfig> InputLoader<CS::EG, CS::ES, G>
);
}

/// Fills the stage at the current k offset.
fn fill_stage_window(_this: &mut Self, _pipeline: Pipeline<CS::ES>, #[comptime] _config: G) {
comptime!(todo!());
}

fn advance_view(this: &mut Self, k_offset: u32) {
this.tensor_view.update_view(k_offset);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ impl<E: Numeric> Im2colReader<E> {
let x =
(out_x * config.stride(1) + kernel_x * config.dilation(1)) as i32 - config.padding(1);

let m_in_bounds = comptime!(!config.check_m_bounds()) || view_m < self.shape_m;
let k_in_bounds = comptime!(!config.check_k_bounds()) || view_k < self.shape_k;
let m_in_bounds = comptime!(!config.check_row_bounds(Ident::Lhs)) || view_m < self.shape_m;
let k_in_bounds = comptime!(!config.check_col_bounds(Ident::Lhs)) || view_k < self.shape_k;
let no_padding = comptime!(config.padding(0) == 0 && config.padding(1) == 0);
let hw_in_bounds = no_padding
|| (y >= 0 && (y as u32) < self.shape_y && x >= 0 && (x as u32) < self.shape_x);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-cubecl/src/kernel/matmul/tune/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fn matmul_accelerated<R: CubeRuntime, E: FloatElement>(
out: CubeTensor<R>,
) -> Result<(), String> {
cubecl::linalg::matmul::launch_ref::<R, E>(
&Strategy::Standard,
&Strategy::Simple,
&lhs.client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
Expand Down

0 comments on commit 136eeb6

Please sign in to comment.