Commit fe71151

test fixes for latest transformers and review fixes
1 parent be1a32d commit fe71151

2 files changed (+9, -19 lines)

optimum/intel/openvino/modeling_decoder.py (+5, -14)
@@ -338,7 +338,6 @@ def compile(self):
         if self.compiled_model is None:
             super().compile()
             self.compiled_model = self.request
-            # self.request = self.request.create_infer_request()
 
     def _make_stateful(self):
         patch_stateful(self.config, self.model)
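
Note on the hunk above: compile() now only caches the compiled model and no longer pre-creates an infer request; requests are created per call instead. A minimal sketch of the underlying OpenVINO pattern this relies on (plain openvino.runtime API, model path is illustrative), where one compiled model hands out independent infer requests:

    from openvino.runtime import Core

    core = Core()
    compiled_model = core.compile_model("model.xml", "CPU")  # illustrative IR path

    # Each infer request owns its input/output tensors and internal state,
    # so several requests created from the same compiled model can run
    # inference concurrently without sharing mutable state.
    request_a = compiled_model.create_infer_request()
    request_b = compiled_model.create_infer_request()
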
@@ -358,16 +357,11 @@ class OVModelForCausalLM(OVBaseDecoderModel, GenerationMixin):
 
     def generate(self, *args, **kwargs):
         self.compile()
-        infer_context = [self.compiled_model.create_infer_request()]
-        kwargs["infer_context"] = infer_context
+        if kwargs.get("infer_request") is None:
+            infer_context = [self.compiled_model.create_infer_request()]
+            kwargs["infer_context"] = infer_context
         return super().generate(*args, **kwargs)
 
-    def __call__(self, *args, **kwargs):
-        self.compile()
-        infer_context = [self.compiled_model.create_infer_request()]
-        kwargs["infer_context"] = infer_context
-        return super().__call__(*args, **kwargs)
-
     @add_start_docstrings_to_model_forward(
         INPUTS_DOCSTRING.format("batch_size, sequence_length")
         + TEXT_GENERATION_EXAMPLE.format(
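
With the change above, generate() creates a fresh infer request only when the caller has not supplied one, and the separate __call__ override is removed. A minimal usage sketch of the pattern this enables (model id, prompts, and max_length are illustrative), where concurrent generate() calls each receive their own request by default:

    import threading

    from transformers import AutoTokenizer
    from optimum.intel import OVModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-gpt2"  # illustrative model id
    model = OVModelForCausalLM.from_pretrained(model_id, export=True)

    def worker(prompt):
        # Tokenizers are not safe to share across threads, so build one per thread.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inputs = tokenizer(prompt, return_tensors="pt")
        # Each generate() call creates its own infer request (see the hunk above),
        # so these calls do not share mutable inference state.
        model.generate(**inputs, max_length=20)

    threads = [threading.Thread(target=worker, args=(p,)) for p in ["This is a sample", "Here is another sample"]]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
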
@@ -482,7 +476,7 @@ def forward(
             # for stateful models, the infer request is created in the generate and __call__ methods and passed through the cycle via the past_key_values param
             infer_request = past_key_values[1]
         else:
-            if infer_context[0] is not None:
+            if infer_context is not None:
                 infer_request = infer_context[
                     0
                 ]  # Use the passed inference request if provided in kwargs, create a new one otherwise
@@ -501,7 +495,7 @@ def forward(
         if not self.stateful:
             if self.use_cache:
                 # Tuple of length equal to: number of layers * number of past_key_values per decoder layer (2 corresponds to the self-attention layer)
-                past_key_values = tuple(infer_context[0].get_tensor(key).data for key in self.key_value_output_names)
+                past_key_values = tuple(infer_request.get_tensor(key).data for key in self.key_value_output_names)
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
                     # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                     past_key_values = tuple(
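
For context on the hunk above: in the non-stateful path, every key/value output tensor is read from the active infer request (now `infer_request` rather than `infer_context[0]`), and the flat tuple is then regrouped into per-layer (key, value) pairs. A small sketch of that regrouping step, with stand-in values instead of tensors (the exact line is truncated in the diff above):

    # flat: (k_0, v_0, k_1, v_1, ...)  ->  nested: ((k_0, v_0), (k_1, v_1), ...)
    flat_past = tuple(range(8))  # stand-in for 4 layers' key/value tensors
    past_key_values = tuple(flat_past[i : i + 2] for i in range(0, len(flat_past), 2))
    assert past_key_values == ((0, 1), (2, 3), (4, 5), (6, 7))
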
@@ -690,9 +684,6 @@ def _reorder_cache(
             batch_size = beam_idx.shape[0]
             indices = np.array(range(batch_size * self.config.num_attention_heads))
             indices = indices.reshape([batch_size, self.config.num_attention_heads])
-            # self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
-            # return past_key_values
-            # print("_reorder_cache output",np.take(indices, beam_idx, 0).flatten())
             return ((np.take(indices, beam_idx, 0).flatten()), past_key_values[1])
         else:
             standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
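
In the stateful branch above, _reorder_cache does not reorder cache tensors itself; it expands the beam indices to per-attention-head indices and returns them together with the infer request carried in past_key_values[1], presumably so the reordering can be applied to the model's internal state on the next forward call. A small numeric sketch of the index expansion (shapes are illustrative):

    import numpy as np

    batch_size, num_attention_heads = 2, 4
    indices = np.arange(batch_size * num_attention_heads).reshape([batch_size, num_attention_heads])
    beam_idx = np.array([1, 1])  # both beams continue from batch element 1
    next_beam_idx = np.take(indices, beam_idx, 0).flatten()
    # next_beam_idx -> [4 5 6 7 4 5 6 7]
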

tests/openvino/test_modeling.py (+4, -5)
@@ -516,11 +516,9 @@ def test_compare_to_transformers(self, model_arch):
         input_shape = tokens["input_ids"].shape
         position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1])
         ov_outputs = ov_model(**tokens, position_ids=position_ids)
-
-        self.assertTrue("logits" in ov_outputs)
         self.assertIsInstance(ov_outputs.logits, torch.Tensor)
 
-        is_stateful = self.IS_SUPPORT_STATEFUL
+        is_stateful = ov_model.config.model_type not in {"gpt_bigcode", "llama"} and self.IS_SUPPORT_STATEFUL
         self.assertEqual(ov_model.stateful, is_stateful)
 
         with torch.no_grad():
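
The updated assertion reflects that, on this branch, stateful export is not applied to gpt_bigcode and llama even when the runtime otherwise supports stateful models. A quick way to inspect the flag on a loaded model (model id is illustrative):

    from optimum.intel import OVModelForCausalLM

    ov_model = OVModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", export=True)
    print(ov_model.config.model_type, ov_model.stateful)
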
@@ -541,7 +539,8 @@ def test_compare_to_transformers_multithreading(self, model_arch):
         ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
         self.assertIsInstance(ov_model.config, PretrainedConfig)
         self.assertTrue(ov_model.use_cache)
-        self.assertEqual(ov_model.stateful, self.IS_SUPPORT_STATEFUL)
+        is_stateful = ov_model.config.model_type not in {"gpt_bigcode", "llama"} and self.IS_SUPPORT_STATEFUL
+        self.assertEqual(ov_model.stateful, is_stateful)
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs_list = ["This is a sample", "Here is another sample", "That's the thrid one", "This is the last sample"]
@@ -607,7 +606,7 @@ def run_ov_model(input_text, model):
             # Tokenizer is not supposed to be shared by multiple threads
             tokenizer = AutoTokenizer.from_pretrained(model_id)
             pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
-            outputs = pipe(input_text, max_length=10)
+            outputs = pipe(input_text, max_length=30)
            self.assertEqual(pipe.device, model.device)
            for i in range(len(outputs)):
                self.assertTrue(all(input_text[i] in item["generated_text"] for item in outputs[i]))
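
Regarding the max_length bump above: for a text-generation pipeline, max_length is the total token budget (prompt plus newly generated tokens), so raising it from 10 to 30 presumably leaves room for generation beyond the prompts used here, keeping the containment check meaningful. A standalone sketch of the same check outside the test harness (model id and prompt are illustrative):

    from transformers import AutoTokenizer, pipeline
    from optimum.intel import OVModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-gpt2"  # illustrative model id
    model = OVModelForCausalLM.from_pretrained(model_id, export=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

    # By default the pipeline returns the prompt plus the continuation,
    # so the prompt should appear verbatim in generated_text.
    outputs = pipe("This is a sample", max_length=30)
    assert "This is a sample" in outputs[0]["generated_text"]
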
