Skip to content

Commit 7c8c56f

Browse files
ctc and speech also uses convolution so has to be deterministic
1 parent 881015c commit 7c8c56f

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

tests/onnxruntime/test_modeling.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -3489,10 +3489,16 @@ def test_compare_to_io_binding(self, model_arch):
34893489

34903490
model_id = MODEL_NAMES[model_arch]
34913491
onnx_model = ORTModelForCTC.from_pretrained(
3492-
self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider"
3492+
self.onnx_model_dirs[model_arch],
3493+
use_io_binding=False,
3494+
provider="CUDAExecutionProvider",
3495+
provider_options={"cudnn_conv_algo_search": "DEFAULT"},
34933496
)
34943497
io_model = ORTModelForCTC.from_pretrained(
3495-
self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider"
3498+
self.onnx_model_dirs[model_arch],
3499+
use_io_binding=True,
3500+
provider="CUDAExecutionProvider",
3501+
provider_options={"cudnn_conv_algo_search": "DEFAULT"},
34963502
)
34973503

34983504
self.assertFalse(onnx_model.use_io_binding)
@@ -4713,10 +4719,20 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
47134719

47144720
model_id = MODEL_NAMES[model_arch]
47154721
onnx_model = ORTModelForSpeechSeq2Seq.from_pretrained(
4716-
self.onnx_model_dirs[test_name], use_io_binding=False, provider="CUDAExecutionProvider"
4722+
self.onnx_model_dirs[test_name],
4723+
use_io_binding=False,
4724+
provider="CUDAExecutionProvider",
4725+
provider_options={
4726+
"cudnn_conv_algo_search": "DEFAULT",
4727+
},
47174728
)
47184729
io_model = ORTModelForSpeechSeq2Seq.from_pretrained(
4719-
self.onnx_model_dirs[test_name], use_io_binding=True, provider="CUDAExecutionProvider"
4730+
self.onnx_model_dirs[test_name],
4731+
use_io_binding=True,
4732+
provider="CUDAExecutionProvider",
4733+
provider_options={
4734+
"cudnn_conv_algo_search": "DEFAULT",
4735+
},
47204736
)
47214737

47224738
self.assertFalse(onnx_model.use_io_binding)

0 commit comments

Comments
 (0)