32
32
set_seed ,
33
33
)
34
34
35
- from optimum .exporters .onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
36
35
from optimum .intel import (
37
36
IPEXModel ,
38
37
IPEXModelForAudioClassification ,
@@ -236,11 +235,8 @@ def test_compare_to_transformers(self, model_arch):
236
235
return_tensors = "pt" ,
237
236
return_token_type_ids = False if model_arch in ("llama" , "llama2" ) else None ,
238
237
)
239
- position_ids = None
240
- if model_arch .replace ("_" , "-" ) in MODEL_TYPES_REQUIRING_POSITION_IDS :
241
- input_shape = tokens ["input_ids" ].shape
242
- position_ids = torch .arange (0 , input_shape [- 1 ], dtype = torch .long ).unsqueeze (0 ).view (- 1 , input_shape [- 1 ])
243
- outputs = ipex_model (** tokens , position_ids = position_ids )
238
+ inputs = ipex_model .prepare_inputs_for_generation (** tokens )
239
+ outputs = ipex_model (** inputs )
244
240
245
241
self .assertIsInstance (outputs .logits , torch .Tensor )
246
242
self .assertIsInstance (outputs .past_key_values , (tuple , list ))
@@ -267,12 +263,15 @@ def test_pipeline(self, model_arch):
267
263
def test_assisted_decoding (self , model_arch ):
268
264
model_id = MODEL_NAMES [model_arch ]
269
265
tokenizer = AutoTokenizer .from_pretrained (model_id )
270
- model = IPEXModelForCausalLM .from_pretrained (model_id , export = True )
271
- assistant_model = AutoModelForCausalLM .from_pretrained (model_id )
266
+ ipex_model = IPEXModelForCausalLM .from_pretrained (model_id , export = True )
267
+ transformers_model = AutoModelForCausalLM .from_pretrained (model_id )
272
268
tokens = tokenizer ("This is a sample input" , return_tensors = "pt" )
273
- output = model .generate (** tokens , do_sample = False )
274
- output_assisted = model .generate (** tokens , do_sample = False , assistant_model = assistant_model )
275
- self .assertTrue (torch .equal (output , output_assisted ))
269
+ ipex_output = ipex_model .generate (** tokens , do_sample = False )
270
+ ipex_output_assisted = ipex_model .generate (** tokens , do_sample = False , assistant_model = transformers_model )
271
+ transformers_output = transformers_model .generate (** tokens , do_sample = False )
272
+ transformers_output_assisted = transformers_model .generate (** tokens , do_sample = False , assistant_model = ipex_model )
273
+ self .assertTrue (torch .equal (ipex_output , ipex_output_assisted ))
274
+ self .assertTrue (torch .equal (transformers_output , transformers_output_assisted ))
276
275
277
276
@parameterized .expand (
278
277
grid_parameters (
0 commit comments