From e23c8efc5bb83e412d4c80cd80ff18759ce228d2 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 12 Feb 2025 14:44:37 -0500 Subject: [PATCH] Fix stage tiling errors (#2803) * Fix stage tiling errors * Remove comment --- .../burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs | 8 ++++---- .../burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs index ea43765281..56a6ac7752 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs @@ -121,13 +121,13 @@ impl SimpleIm2col { let (tile_x, tile_y) = match config.tiling_order(ident) { TilingOrderConfig::RowMajor => RowMajorTiling::to_x_y( nth_tile, - stage_tiling.total_row(), - stage_tiling.total_col(), + stage_tiling.tile_count_row(), + stage_tiling.tile_count_col(), ), TilingOrderConfig::ColMajor => ColMajorTiling::to_x_y( nth_tile, - stage_tiling.total_row(), - stage_tiling.total_col(), + stage_tiling.tile_count_row(), + stage_tiling.tile_count_col(), ), }; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs index 19d2bebfc2..ea89b3a165 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs @@ -98,8 +98,8 @@ impl Im2colReader { #[comptime] config: G, ) -> Line { let line_size = config.global_line_size(ident); - let tile_size_x = config.stage_tiling(ident).total_row(); - let tile_size_y = config.stage_tiling(ident).total_col(); + let tile_size_x = config.stage_tiling(ident).tile_shape_row(); + let tile_size_y = config.stage_tiling(ident).tile_shape_col(); let view_tile_m = tile_x * tile_size_x + self.m_offset; let view_tile_k = tile_y * tile_size_y + self.k_offset;