Commit 8257b86

Skylion007 authored and pytorchmergebot committed
Add bfloat16 CUDA support to binomial distribution (pytorch#116932)
Now all distributions support bfloat16 as input.

Pull Request resolved: pytorch#116932
Approved by: https://github.com/malfet
1 parent 4a37f57 commit 8257b86

File tree: 2 files changed (+4, −1)

aten/src/ATen/native/cuda/Distributions.cu (+1, −1)
@@ -167,7 +167,7 @@ void launch_binomial_cuda_kernel(
     std::lock_guard<std::mutex> lock(gen->mutex_);
     rng_engine_inputs = gen->philox_cuda_state(42);
   }
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.input_dtype(), "binomial_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "binomial_cuda", [&] {
     binomial_cuda_kernel<scalar_t>(iter, rng_engine_inputs);
   });
 }
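
The change swaps AT_DISPATCH_FLOATING_TYPES_AND_HALF for AT_DISPATCH_FLOATING_TYPES_AND2 so the dispatch list covers BFloat16 in addition to Half. A minimal usage sketch of what this enables from Python (not part of the commit; assumes a CUDA build and a bfloat16-capable GPU):

import torch
from torch.distributions import Binomial

# bfloat16 probabilities on the GPU now reach the binomial CUDA kernel
# instead of failing dtype dispatch.
probs = torch.tensor([0.25, 0.75], device="cuda", dtype=torch.bfloat16)
dist = Binomial(total_count=10, probs=probs)
samples = dist.sample((4,))  # shape (4, 2); dtype follows probs, i.e. bfloat16
print(samples.dtype, samples.device)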

test/distributions/test_distributions.py (+3)
@@ -1151,6 +1151,9 @@ def test_binomial(self):
         self._gradcheck_log_prob(lambda p: Binomial(total_count, None, p.log()), [p])
         self.assertRaises(NotImplementedError, Binomial(10, p).rsample)
 
+    test_binomial_half = set_default_dtype(torch.float16)(test_binomial)
+    test_binomial_bfloat16 = set_default_dtype(torch.bfloat16)(test_binomial)
+
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     def test_binomial_sample(self):
         set_rng_seed(0)  # see Note [Randomized statistical tests]
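
The two added lines reuse the existing test_binomial body under float16 and bfloat16 by wrapping it with the test suite's set_default_dtype decorator. Roughly, that helper behaves like the sketch below (a simplified stand-in, not the actual test-suite implementation): it swaps the global default floating-point dtype for the duration of the test and restores it afterwards.

import functools
import torch

def set_default_dtype(dtype):
    # Decorator factory: run the wrapped test with `dtype` as the default
    # floating-point dtype, then restore the previous default.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            saved = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                return fn(*args, **kwargs)
            finally:
                torch.set_default_dtype(saved)
        return wrapper
    return decorator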
