Skip to content

Commit efcaed8

Browse files
committed
refacto test
1 parent 40dd236 commit efcaed8

File tree

1 file changed

+15
-39
lines changed

1 file changed

+15
-39
lines changed

tests/onnxruntime/test_modeling.py

+15-39
Original file line numberDiff line numberDiff line change
@@ -141,63 +141,39 @@ def __init__(self, *args, **kwargs):
141141
self.TINY_ONNX_SEQ2SEQ_MODEL_ID = "fxmarty/sshleifer-tiny-mbart-onnx"
142142
self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID = "optimum-internal-testing/tiny-stable-diffusion-onnx"
143143

144-
def test_load_onnx_model_from_hub(self):
144+
@parameterized.expand((ORTModelForCausalLM, ORTModel))
145+
def test_load_onnx_model_from_hub(self, model_cls):
145146
model_id = "optimum-internal-testing/tiny-random-llama"
146147
file_name = "model_optimized.onnx"
147148

148-
model = ORTModel.from_pretrained(model_id)
149+
model = model_cls.from_pretrained(model_id)
149150
self.assertEqual(model.model_path.name, "model.onnx")
150151

151-
model = ORTModel.from_pretrained(model_id, revision="onnx")
152+
model = model_cls.from_pretrained(model_id, revision="onnx")
152153
self.assertEqual(model.model_path.name, "model.onnx")
153154

154-
model = ORTModel.from_pretrained(model_id, revision="onnx", file_name=file_name)
155+
model = model_cls.from_pretrained(model_id, revision="onnx", file_name=file_name)
155156
self.assertEqual(model.model_path.name, file_name)
156157

157-
model = ORTModel.from_pretrained(model_id, revision="merged-onnx", file_name=file_name)
158+
model = model_cls.from_pretrained(model_id, revision="merged-onnx", file_name=file_name)
158159
self.assertEqual(model.model_path.name, file_name)
159160

160-
model = ORTModel.from_pretrained(model_id, revision="merged-onnx", subfolder="subfolder")
161-
self.assertEqual(model.model_path.name, "model.onnx")
162-
163-
model = ORTModel.from_pretrained(model_id, revision="merged-onnx", subfolder="subfolder", file_name=file_name)
164-
self.assertEqual(model.model_path.name, file_name)
165-
166-
model = ORTModel.from_pretrained(model_id, revision="merged-onnx", file_name="decoder_with_past_model.onnx")
167-
self.assertEqual(model.model_path.name, "decoder_with_past_model.onnx")
168-
169-
def test_load_decoder_onnx_model_from_hub(self):
170-
model_id = "optimum-internal-testing/tiny-random-llama"
171-
file_name = "model_optimized.onnx"
172-
173-
model = ORTModelForCausalLM.from_pretrained(model_id)
174-
self.assertEqual(model.model_path.name, "model.onnx")
161+
if model_cls is ORTModelForCausalLM:
162+
model = model_cls.from_pretrained(model_id, revision="merged-onnx")
163+
self.assertEqual(model.model_path.name, "decoder_model_merged.onnx")
175164

176-
model = ORTModelForCausalLM.from_pretrained(model_id, revision="onnx")
165+
model = model_cls.from_pretrained(model_id, revision="merged-onnx", subfolder="subfolder")
177166
self.assertEqual(model.model_path.name, "model.onnx")
178167

179-
model = ORTModelForCausalLM.from_pretrained(model_id, revision="onnx", file_name=file_name)
168+
model = model_cls.from_pretrained(model_id, revision="merged-onnx", subfolder="subfolder", file_name=file_name)
180169
self.assertEqual(model.model_path.name, file_name)
181170

182-
model = ORTModelForCausalLM.from_pretrained(model_id, revision="merged-onnx", file_name=file_name)
183-
self.assertEqual(model.model_path.name, file_name)
184-
185-
model = ORTModelForCausalLM.from_pretrained(model_id, revision="merged-onnx")
186-
self.assertEqual(model.model_path.name, "decoder_model_merged.onnx")
187-
188-
model = ORTModelForCausalLM.from_pretrained(model_id, revision="merged-onnx", subfolder="subfolder")
189-
self.assertEqual(model.model_path.name, "model.onnx")
190-
191-
model = ORTModelForCausalLM.from_pretrained(
192-
model_id, revision="merged-onnx", subfolder="subfolder", file_name=file_name
193-
)
194-
self.assertEqual(model.model_path.name, file_name)
195-
196-
model = ORTModelForCausalLM.from_pretrained(
197-
model_id, revision="merged-onnx", file_name="decoder_with_past_model.onnx"
198-
)
171+
model = model_cls.from_pretrained(model_id, revision="merged-onnx", file_name="decoder_with_past_model.onnx")
199172
self.assertEqual(model.model_path.name, "decoder_with_past_model.onnx")
200173

174+
with self.assertRaises(FileNotFoundError):
175+
model_cls.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM", file_name="test.onnx")
176+
201177
def test_load_model_from_local_path(self):
202178
model = ORTModel.from_pretrained(self.LOCAL_MODEL_PATH)
203179
self.assertIsInstance(model.model, onnxruntime.InferenceSession)

0 commit comments

Comments
 (0)