Commit 4b826b3

Run sort UTs
pragupta authored and pruthvistony committed Mar 4, 2025
1 parent a1efa0c commit 4b826b3
Showing 2 changed files with 7 additions and 16 deletions.
aten/src/ATen/native/cuda/Sort.cpp (4 additions, 3 deletions)
@@ -63,9 +63,6 @@ void sort_cuda_kernel(
     "The dimension being sorted can not have more than INT_MAX elements.");
 
   const auto self_dtype = self.dtype();
-  // FIXME: remove this check once cub sort supports bool
-  TORCH_CHECK(self_dtype != ScalarType::Bool,
-    "Sort currently does not support bool dtype on CUDA.");
   TORCH_CHECK(self_dtype != ScalarType::ComplexFloat && self_dtype != ScalarType::ComplexDouble,
     "Sort currently does not support complex dtypes on CUDA.");
 #if defined(USE_ROCM)
@@ -76,6 +73,10 @@ void sort_cuda_kernel(
   if (self_dtype == ScalarType::Bool) {
     self.copy_(self.to(at::kByte));
   }
+#else
+  // FIXME: remove this check once cub sort supports bool
+  TORCH_CHECK(self_dtype != ScalarType::Bool,
+    "Sort currently does not support bool dtype on CUDA.");
 #endif
 
   // use inplace algorithm for smaller input sizes without stable=True
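
The net effect of the Sort.cpp change: the bool-dtype guard now applies only to non-ROCm (cub) builds, while the ROCm branch keeps converting bool input to kByte before sorting. A minimal sketch of the user-visible behavior, assuming a ROCm build of PyTorch that carries this patch; the device string and sample values are illustrative, not part of the commit:

import torch

# Assumes a ROCm build with this patch; ROCm devices are addressed as "cuda" in PyTorch.
x = torch.tensor([True, False, True, False], device="cuda")
values, indices = torch.sort(x)
print(values)  # expected: tensor([False, False, True, True], device='cuda:0')

# On a non-ROCm (cub) build the guard above still raises:
# RuntimeError: Sort currently does not support bool dtype on CUDA.
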
torch/testing/_internal/common_methods_invocations.py (3 additions, 13 deletions)
@@ -18304,13 +18304,11 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            )),
     OpInfo('sort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_sort,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
-                            dtypes=[torch.bool], device_type='cuda', active_if=not TEST_WITH_ROCM),
            )),
     OpInfo('unique',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64),
@@ -19337,7 +19335,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            sample_inputs_func=sample_inputs_unfold),
     OpInfo('msort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            check_batched_gradgrad=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -21117,7 +21115,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
     OpInfo(
         "argsort",
         dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_sort,
         supports_out=False,
         supports_autograd=False,
@@ -21128,14 +21126,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
                 "test_variant_consistency_jit",
                 dtypes=(torch.float32,),
             ),
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestCommon",
-                "test_non_standard_bool_values",
-                dtypes=[torch.bool],
-                device_type='cuda',
-                active_if=not TEST_WITH_ROCM
-            ),
         ),
     ),
     OpInfo(
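
On the test side, torch.bool is added to dtypesIfCUDA for sort, msort, and argsort, and the expectedFailure DecorateInfo entries for test_non_standard_bool_values (previously active only off ROCm, via active_if=not TEST_WITH_ROCM) are dropped, so the OpInfo-generated sort UTs now include bool on the GPU device type. A rough way to confirm the new dtype coverage, sketched under the assumption that PyTorch's internal (unstable) op_db and OpInfo.supported_dtypes helpers behave as described; this check is not part of the commit:

import torch
# Internal testing helpers; importable from a PyTorch install that includes this change.
from torch.testing._internal.common_methods_invocations import op_db

for name in ("sort", "msort", "argsort"):
    op = next(o for o in op_db if o.name == name)
    # supported_dtypes("cuda") should reflect dtypesIfCUDA on CUDA/ROCm devices.
    assert torch.bool in op.supported_dtypes("cuda"), f"{name} should now advertise bool on CUDA"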
