Skip to content

Commit

Permalink
Fix stage tiling errors (#2803)
Browse files Browse the repository at this point in the history
* Fix stage tiling errors

* Remove comment
  • Loading branch information
laggui authored Feb 12, 2025
1 parent 8b8a08b commit e23c8ef
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ impl SimpleIm2col {
let (tile_x, tile_y) = match config.tiling_order(ident) {
TilingOrderConfig::RowMajor => RowMajorTiling::to_x_y(
nth_tile,
stage_tiling.total_row(),
stage_tiling.total_col(),
stage_tiling.tile_count_row(),
stage_tiling.tile_count_col(),
),
TilingOrderConfig::ColMajor => ColMajorTiling::to_x_y(
nth_tile,
stage_tiling.total_row(),
stage_tiling.total_col(),
stage_tiling.tile_count_row(),
stage_tiling.tile_count_col(),
),
};

Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ impl<E: Numeric> Im2colReader<E> {
#[comptime] config: G,
) -> Line<E> {
let line_size = config.global_line_size(ident);
let tile_size_x = config.stage_tiling(ident).total_row();
let tile_size_y = config.stage_tiling(ident).total_col();
let tile_size_x = config.stage_tiling(ident).tile_shape_row();
let tile_size_y = config.stage_tiling(ident).tile_shape_col();

let view_tile_m = tile_x * tile_size_x + self.m_offset;
let view_tile_k = tile_y * tile_size_y + self.k_offset;
Expand Down

0 comments on commit e23c8ef

Please sign in to comment.