
Commit d216e3a

fix stateless model support
1 parent f82ff06 commit d216e3a

2 files changed: +104 -47 lines changed

optimum/intel/openvino/modeling_decoder.py (+80 -40)

@@ -14,10 +14,9 @@
 import copy
 import logging
 import os
-import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import openvino
@@ -28,21 +27,10 @@
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.generation import GenerationMixin
-from transformers.generation.beam_search import BeamScorer, ConstrainedBeamSearchScorer
 from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
 from transformers.generation.logits_process import LogitsProcessorList
-from transformers.generation.stopping_criteria import (
-    EosTokenCriteria,
-    StoppingCriteriaList,
-    validate_stopping_criteria,
-)
-from transformers.generation.utils import (
-    GenerateBeamDecoderOnlyOutput,
-    GenerateBeamOutput,
-    GenerateOutput,
-    _split_model_inputs,
-    stack_model_outputs,
-)
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import CausalLMOutputWithPast

 from optimum.utils.normalized_config import NormalizedConfigManager
@@ -398,7 +386,11 @@ def prepare_inputs(
         inputs = {}
         if not self.stateful:
             if past_key_values is not None:
-                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -491,9 +483,6 @@ def forward(
             position_ids=position_ids,
             **kwargs,
         )
-
-        print(inputs["input_ids"].shape)
-
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
@@ -509,7 +498,11 @@ def forward(
         if self.use_cache:
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
             past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+            if (
+                self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                or self.config.model_type == "falcon"
+                and self.config.new_decoder_architecture
+            ):
                 # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                 past_key_values = tuple(
                     past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
@@ -561,21 +554,33 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         return model_inputs

     def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
         if indicies.shape[0] != 1:
             logits = logits[indicies]
             if past_key_values and not self.stateful:
-                past_key_values = tuple(
-                    tuple(
-                        past_state[indicies]
-                        if not self.config.model_type == "chatglm"
-                        else past_state[:, indicies, ...]
-                        for past_state in layer_past
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
+                    past_key_values = tuple(
+                        tuple(
+                            past_state[indicies]
+                            if not self.config.model_type == "chatglm"
+                            else past_state[:, indicies, ...]
+                            for past_state in layer_past
+                        )
+                        for layer_past in past_key_values
                     )
-                    for layer_past in past_key_values
-                )
-            if self.stateful:
-                self.next_beam_idx = self.next_beam_idx[indicies]
-                self._second_iter_beam_search = True
+                else:
+                    past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+                self._second_iter_beam_search = True
         return logits, past_key_values

     def _deduplicate_inputs(self, model_inputs: Dict):
@@ -591,12 +596,19 @@ def _deduplicate_inputs(self, model_inputs: Dict):
                 else:
                     shape = input_tensor.shape
                     dtype = input_tensor.element_type
-                    shape[0 if not self.config.model_type == "chatglm" else 1] = indicies.shape[0]
+                    upd_batch_size = indicies.shape[0]
+                    if self.config.model_type == "bloom":
+                        upd_batch_size *= self.config.num_attention_heads
+                    shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
                     upd_model_inputs[input_name] = Tensor(dtype, shape)
-                print(f"{input_name}: {upd_model_inputs[input_name].shape}")
         upd_model_inputs["input_ids"] = unique_input_ids
         if "beam_idx" in model_inputs:
-            beam_idx = np.arange(unique_input_ids.shape[0], dtype=int)
+            beam_range = (
+                unique_input_ids.shape[0]
+                if self.config.model_type != "bloom"
+                else unique_input_ids.shape[0] * self.config.num_attention_heads
+            )
+            beam_idx = np.arange(beam_range, dtype=int)
             upd_model_inputs["beam_idx"] = beam_idx
         return upd_model_inputs, reverse_indicies

@@ -646,7 +658,9 @@ def _get_past_length(self, past_key_values=None):
             return 0
         if self.stateful:
             return self._past_length
-        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
+            self.config.model_type == "falcon" and self.config.new_decoder_architecture
+        ):
             return past_key_values[0].shape[-2]
         seq_length_dim = -2
         if self.config.model_type == "chatglm":
@@ -677,9 +691,14 @@ def _reorder_cache(
             self._second_iter_beam_search = False
             return past_key_values
         else:
-            return tuple(
-                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
-            )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
+            ):
+                return tuple(
+                    tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
+                    for layer_past in past_key_values
+                )
+            return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)

     def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -800,11 +819,12 @@ def _reorder_cache(
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
         if self.stateful:
-            beam_idx = np.array(beam_idx)
             batch_size = beam_idx.shape[0]
+            beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
             indices = np.array(range(batch_size * self.config.num_attention_heads))
             indices = indices.reshape([batch_size, self.config.num_attention_heads])
             self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
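
The stateful branch above maps a per-sequence beam index onto Bloom's fused batch x attention-heads cache dimension. As a rough standalone illustration (the head count and beam permutation are made up for the example, not taken from the patch):

    import numpy as np

    batch_size = 2
    num_attention_heads = 4          # assumed head count for the example
    beam_idx = np.array([1, 0])      # beam 0 now continues sequence 1 and vice versa

    # Mirror the reshape/np.take above: expand the per-sequence index into one entry
    # per (sequence, head) row of the fused cache.
    indices = np.arange(batch_size * num_attention_heads).reshape([batch_size, num_attention_heads])
    next_beam_idx = np.take(indices, beam_idx, 0).flatten()

    print(next_beam_idx)  # [4 5 6 7 0 1 2 3]: swapping the sequences moves all of their head rows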
@@ -854,14 +874,34 @@ def _convert_to_standard_cache(
             for layer_past in past_key_value
         )

+    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indicies.shape[0] != 1:
+            logits = logits[indicies]
+            if past_key_values and not self.stateful:
+                pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
+                pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard)
+                past_key_values = self._convert_to_bloom_cache(pkv)
+
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+                self._second_iter_beam_search = True
+        return logits, past_key_values
+

 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
         if self.stateful:
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
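
The new condition in prepare_inputs, forward, _expand_outputs_for_generation, _get_past_length and _reorder_cache is the same predicate each time: treat the cache as per-layer (key, value) tuples unless the model is a multi-query-attention model, with Falcon in its new_decoder_architecture configuration handled like the non-MQA case. A minimal standalone restatement (the helper name, the SimpleNamespace configs and the set contents are illustrative, not part of the patch):

    from types import SimpleNamespace

    MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}  # assumed contents for the example


    def uses_per_layer_kv_pairs(config) -> bool:
        # Models outside MULTI_QUERY_ATTN_MODELS keep one (key, value) tuple per decoder
        # layer; Falcon does too when exported with new_decoder_architecture=True.
        # `and` binds tighter than `or`, which is what the patched condition relies on.
        if config.model_type not in MULTI_QUERY_ATTN_MODELS:
            return True
        return config.model_type == "falcon" and getattr(config, "new_decoder_architecture", False)


    assert uses_per_layer_kv_pairs(SimpleNamespace(model_type="llama"))
    assert uses_per_layer_kv_pairs(SimpleNamespace(model_type="falcon", new_decoder_architecture=True))
    assert not uses_per_layer_kv_pairs(SimpleNamespace(model_type="falcon", new_decoder_architecture=False))
    assert not uses_per_layer_kv_pairs(SimpleNamespace(model_type="gpt_bigcode"))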

tests/openvino/test_modeling.py (+24 -7)

@@ -778,14 +778,31 @@ def test_default_filling_attention_mask_and_position_ids(self):
         del model_with_cache
         gc.collect()

-    def test_beam_search(self):
-        model_id = MODEL_NAMES["llama"]
-        ov_model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
-        ov_model_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
+    def test_beam_search(self, model_arch):
+        model_kwargs = {}
+        model_id = MODEL_NAMES[model_arch]
+        if model_arch in self.REMOTE_CODE_MODELS:
+            model_kwargs = {
+                "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
+                "trust_remote_code": True,
+            }
+        # The Qwen tokenizer does not support padding, and the chatglm test model produces NaN values that are incompatible with beam search
+        if model_arch in ["qwen", "chatglm"]:
+            return

-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
+        ov_model_stateful = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=True, **model_kwargs
+        )
+        ov_model_stateless = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=False, **model_kwargs
+        )
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
         tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
         ov_model_stateful.generation_config.eos_token_id = None
         ov_model_stateless.generation_config.eos_token_id = None
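
The remainder of the test sits outside this hunk. Standalone, the stateful/stateless comparison it parameterizes might look roughly like the sketch below; the model id, beam settings and exact-match assertions are illustrative and not taken from the test suite.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    from optimum.intel import OVModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # example tiny model, not necessarily the one used in CI
    gen_config = GenerationConfig(max_new_tokens=10, num_beams=4, do_sample=False, eos_token_id=None)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)

    ov_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
    ov_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
    reference = AutoModelForCausalLM.from_pretrained(model_id)

    ref_out = reference.generate(**tokens, generation_config=gen_config)
    stateful_out = ov_stateful.generate(**tokens, generation_config=gen_config)
    stateless_out = ov_stateless.generate(**tokens, generation_config=gen_config)

    # With the fix, the stateless model is expected to match both the stateful one
    # and the transformers reference under beam search.
    assert torch.equal(stateful_out, ref_out)
    assert torch.equal(stateless_out, ref_out)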
