Skip to content

Commit

Permalink
Fix sort descending for 1d case (#1494)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Mar 21, 2024
1 parent e8863da commit 3e4af41
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
3 changes: 2 additions & 1 deletion crates/burn-tensor/src/tensor/api/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ where
let dims = data.shape.dims;
if D == 1 {
// 1D sort
data.value.sort_unstable_by(|&a, &b| a.cmp(&b));
data.value
.sort_unstable_by(|&a, &b| compare(&a, &b, descending));
} else {
sort_slice::<B, D, K>(&mut data.value, &dims, dim, None, false, descending);
}
Expand Down
21 changes: 21 additions & 0 deletions crates/burn-tensor/src/tests/ops/sort_argsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,4 +359,25 @@ mod tests {
let values_expected = Data::from([[-0.5, 0.94], [-0.3, f32::NAN], [0., f32::NAN]]);
values_expected.assert_approx_eq(&values_actual, 5);
}

#[test]
fn test_sort_descending_1d() {
let tensor = TestTensorInt::from([1, 2, 3, 4, 5]);

// Sort along dim=0
let values = tensor.sort_descending(0);
let values_actual = values.into_data();

let values_expected = Data::from([5, 4, 3, 2, 1]);
assert_eq!(values_expected, values_actual);

let tensor = TestTensor::from([1., 2., 3., 4., 5.]);

// Sort along dim=0
let values = tensor.sort_descending(0);
let values_actual = values.into_data();

let values_expected = Data::from([5., 4., 3., 2., 1.]);
values_expected.assert_approx_eq(&values_actual, 5);
}
}

0 comments on commit 3e4af41

Please sign in to comment.