File tree 1 file changed +11
-0
lines changed
1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -275,6 +275,17 @@ def test_pipeline(self, model_arch):
275
275
self .assertEqual (pipe .device , model .device )
276
276
self .assertTrue (all ("This is a sample" in item ["generated_text" ] for item in outputs ))
277
277
278
+ @parameterized .expand (SUPPORTED_ARCHITECTURES )
279
+ def test_assisted_decoding (self , model_arch ):
280
+ model_id = MODEL_NAMES [model_arch ]
281
+ tokenizer = AutoTokenizer .from_pretrained (model_id )
282
+ model = IPEXModelForCausalLM .from_pretrained (model_id , export = True )
283
+ assistant_model = AutoModelForCausalLM .from_pretrained (model_id )
284
+ tokens = tokenizer ("This is a sample input" , return_tensors = "pt" )
285
+ output = model .generate (** tokens , do_sample = False )
286
+ output_assisted = model .generate (** tokens , do_sample = False , assistant_model = assistant_model )
287
+ self .assertTrue (torch .equal (output , output_assisted ))
288
+
278
289
def test_compare_with_and_without_past_key_values (self ):
279
290
model_id = "echarlaix/tiny-random-gpt2-torchscript"
280
291
tokenizer = AutoTokenizer .from_pretrained (model_id )
You can’t perform that action at this time.
0 commit comments