@@ -141,63 +141,39 @@ def __init__(self, *args, **kwargs):
141
141
self .TINY_ONNX_SEQ2SEQ_MODEL_ID = "fxmarty/sshleifer-tiny-mbart-onnx"
142
142
self .TINY_ONNX_STABLE_DIFFUSION_MODEL_ID = "optimum-internal-testing/tiny-stable-diffusion-onnx"
143
143
144
- def test_load_onnx_model_from_hub (self ):
144
+ @parameterized .expand ((ORTModelForCausalLM , ORTModel ))
145
+ def test_load_onnx_model_from_hub (self , model_cls ):
145
146
model_id = "optimum-internal-testing/tiny-random-llama"
146
147
file_name = "model_optimized.onnx"
147
148
148
- model = ORTModel .from_pretrained (model_id )
149
+ model = model_cls .from_pretrained (model_id )
149
150
self .assertEqual (model .model_path .name , "model.onnx" )
150
151
151
- model = ORTModel .from_pretrained (model_id , revision = "onnx" )
152
+ model = model_cls .from_pretrained (model_id , revision = "onnx" )
152
153
self .assertEqual (model .model_path .name , "model.onnx" )
153
154
154
- model = ORTModel .from_pretrained (model_id , revision = "onnx" , file_name = file_name )
155
+ model = model_cls .from_pretrained (model_id , revision = "onnx" , file_name = file_name )
155
156
self .assertEqual (model .model_path .name , file_name )
156
157
157
- model = ORTModel .from_pretrained (model_id , revision = "merged-onnx" , file_name = file_name )
158
+ model = model_cls .from_pretrained (model_id , revision = "merged-onnx" , file_name = file_name )
158
159
self .assertEqual (model .model_path .name , file_name )
159
160
160
- model = ORTModel .from_pretrained (model_id , revision = "merged-onnx" , subfolder = "subfolder" )
161
- self .assertEqual (model .model_path .name , "model.onnx" )
162
-
163
- model = ORTModel .from_pretrained (model_id , revision = "merged-onnx" , subfolder = "subfolder" , file_name = file_name )
164
- self .assertEqual (model .model_path .name , file_name )
165
-
166
- model = ORTModel .from_pretrained (model_id , revision = "merged-onnx" , file_name = "decoder_with_past_model.onnx" )
167
- self .assertEqual (model .model_path .name , "decoder_with_past_model.onnx" )
168
-
169
- def test_load_decoder_onnx_model_from_hub (self ):
170
- model_id = "optimum-internal-testing/tiny-random-llama"
171
- file_name = "model_optimized.onnx"
172
-
173
- model = ORTModelForCausalLM .from_pretrained (model_id )
174
- self .assertEqual (model .model_path .name , "model.onnx" )
161
+ if model_cls is ORTModelForCausalLM :
162
+ model = model_cls .from_pretrained (model_id , revision = "merged-onnx" )
163
+ self .assertEqual (model .model_path .name , "decoder_model_merged.onnx" )
175
164
176
- model = ORTModelForCausalLM .from_pretrained (model_id , revision = "onnx" )
165
+ model = model_cls .from_pretrained (model_id , revision = "merged- onnx" , subfolder = "subfolder " )
177
166
self .assertEqual (model .model_path .name , "model.onnx" )
178
167
179
- model = ORTModelForCausalLM .from_pretrained (model_id , revision = "onnx" , file_name = file_name )
168
+ model = model_cls .from_pretrained (model_id , revision = "merged- onnx" , subfolder = "subfolder " , file_name = file_name )
180
169
self .assertEqual (model .model_path .name , file_name )
181
170
182
- model = ORTModelForCausalLM .from_pretrained (model_id , revision = "merged-onnx" , file_name = file_name )
183
- self .assertEqual (model .model_path .name , file_name )
184
-
185
- model = ORTModelForCausalLM .from_pretrained (model_id , revision = "merged-onnx" )
186
- self .assertEqual (model .model_path .name , "decoder_model_merged.onnx" )
187
-
188
- model = ORTModelForCausalLM .from_pretrained (model_id , revision = "merged-onnx" , subfolder = "subfolder" )
189
- self .assertEqual (model .model_path .name , "model.onnx" )
190
-
191
- model = ORTModelForCausalLM .from_pretrained (
192
- model_id , revision = "merged-onnx" , subfolder = "subfolder" , file_name = file_name
193
- )
194
- self .assertEqual (model .model_path .name , file_name )
195
-
196
- model = ORTModelForCausalLM .from_pretrained (
197
- model_id , revision = "merged-onnx" , file_name = "decoder_with_past_model.onnx"
198
- )
171
+ model = model_cls .from_pretrained (model_id , revision = "merged-onnx" , file_name = "decoder_with_past_model.onnx" )
199
172
self .assertEqual (model .model_path .name , "decoder_with_past_model.onnx" )
200
173
174
+ with self .assertRaises (FileNotFoundError ):
175
+ model_cls .from_pretrained ("hf-internal-testing/tiny-random-LlamaForCausalLM" , file_name = "test.onnx" )
176
+
201
177
def test_load_model_from_local_path (self ):
202
178
model = ORTModel .from_pretrained (self .LOCAL_MODEL_PATH )
203
179
self .assertIsInstance (model .model , onnxruntime .InferenceSession )
0 commit comments