Skip to content

Commit

Permalink
Fix/fuse select (#2804)
Browse files Browse the repository at this point in the history
* Fix

* Cleanup

* Well another fix

* Cleanup

* Remove println
  • Loading branch information
nathanielsimard authored Feb 12, 2025
1 parent e23c8ef commit 1962fb1
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 43 deletions.
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/elemwise/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {

impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<R> {
fn register(&mut self, operation: &burn_ir::OperationIr) {
self.builder.register(operation)
self.builder.register(operation);
}

fn build(&self) -> JitOptimization<R> {
Expand Down
150 changes: 111 additions & 39 deletions crates/burn-jit/src/fusion/on_write/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1147,63 +1147,135 @@ fn select_indices<C: Numeric>(
_ => panic!("Indices tensor isn't an input"),
};

let stride_input = global_stride(inputs, dim, pos_input, precision_input);
let stride_input_dim = global_stride(inputs, dim, pos_input, precision_input);

let mut index = Line::empty(line_size_ref).fill(0);
let mut index = 0u32;
let mut result = Line::empty(line_size_ref);

if comptime![dim > 0] {
let index_before = global_offset(
inputs,
outputs,
write_pos,
comment!(input.clone()),
comptime![Some((0u32, dim))],
config,
);
index += Line::new(index_before);
}
if comptime![dim != config.rank - 1] {
// In this scenario the select is actually broadcasted along the axis we're working on.
//
// Therefore the same indices are used to fetch multiple entries in the input tensor.

if comptime![dim + 1 < config.rank] {
let index_after = global_offset(
let write_pos_input = write_pos * line_size_ref;
let stride_input_line = global_stride(
inputs,
outputs,
write_pos,
input,
comptime![Some((dim + 1, config.rank))],
config,
comptime![config.rank - 1],
pos_input,
precision_input,
);
index += Line::new(index_after);
}

let mut result = Line::empty(line_size_ref);
if comptime![dim > 0] {
let index_before = global_offset(
inputs,
outputs,
write_pos_input,
comment!(input.clone()),
comptime![Some((0u32, dim))],
config,
);
index += index_before;
}

#[unroll]
for i in 0..line_size_ref {
let index_indices = ((write_pos * line_size_ref) + i) / stride_dim_ref % shape_dim_ref;
if comptime![dim + 1 < config.rank] {
let index_after = global_offset(
inputs,
outputs,
write_pos_input,
comment!(input.clone()),
comptime![Some((dim + 1, config.rank))],
config,
);
index += index_after;
}

let coordinate_dim = write_pos_input / stride_dim_ref % shape_dim_ref;
let offset_dim = read_input::<u32>(
inputs,
outputs,
pos_indices,
index_indices,
coordinate_dim,
LayoutInfo::IsRef,
precision_indices,
config,
None,
);
let index = index[i] + offset_dim[0] * stride_input;

let input = read_input::<C>(
inputs,
outputs,
pos_input,
index,
LayoutInfo::IsRef,
precision_input,
config,
None,
);
result[i] = input[0];
index *= line_size_ref;
index += offset_dim[0] * stride_input_dim;

#[unroll]
for i in 0..line_size_ref {
let input = read_input::<C>(
inputs,
outputs,
pos_input,
index + i * stride_input_line,
LayoutInfo::IsRef,
precision_input,
config,
None,
);
result[i] = input[0];
}
} else {
// In this scenario the select is actually performed on the last dimension we're working on.
//
// Therefore we need to fetch multiple indices that correspond to different entries in the
// input tensor.

if comptime![dim > 0] {
let index_before = global_offset(
inputs,
outputs,
write_pos,
comment!(input.clone()),
comptime![Some((0u32, dim))],
config,
);
index += index_before;
}

if comptime![dim + 1 < config.rank] {
let index_after = global_offset(
inputs,
outputs,
write_pos,
input,
comptime![Some((dim + 1, config.rank))],
config,
);
index += index_after;
}

let write_pos_indices = write_pos * line_size_ref;

#[unroll]
for i in 0..line_size_ref {
let coordinate_dim = (write_pos_indices + i) / stride_dim_ref % shape_dim_ref;
let offset_dim = read_input::<u32>(
inputs,
outputs,
pos_indices,
coordinate_dim,
LayoutInfo::IsRef,
precision_indices,
config,
None,
);

let input = read_input::<C>(
inputs,
outputs,
pos_input,
index + (offset_dim[0] * stride_input_dim),
LayoutInfo::IsRef,
precision_input,
config,
None,
);
result[i] = input[0];
}
}

write::<C>(inputs, outputs, locals, write_pos, result, output, config);
Expand Down
13 changes: 13 additions & 0 deletions crates/burn-tensor/src/tests/ops/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ mod tests {
output.into_data().assert_eq(&expected, false);
}

#[test]
fn should_select_2d_dim0_vec() {
let device = Default::default();
let tensor =
TestTensor::<2>::from_data([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], &device);
let indices = TestTensorInt::from_data([1, 0, 3, 2], &device);

let output = tensor.select(0, indices);
let expected = TensorData::from([[2.0, 3.0], [0.0, 1.0], [6.0, 7.0], [4.0, 5.0]]);

output.into_data().assert_eq(&expected, false);
}

#[test]
fn should_select_2d_dim1() {
let device = Default::default();
Expand Down
7 changes: 4 additions & 3 deletions crates/burn-tensor/src/tests/ops/topk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ mod tests {
}

#[test]
fn test_topk_with_indices() {
// 1D
fn test_topk_with_indices_1d() {
let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]);

let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0);
Expand All @@ -54,8 +53,10 @@ mod tests {

let indices_expected = TensorData::from([4, 3, 2]);
indices.into_data().assert_eq(&indices_expected, false);
}

// 3D
#[test]
fn test_topk_with_indices_3d() {
let tensor =
TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]);

Expand Down

0 comments on commit 1962fb1

Please sign in to comment.