
Commit 0dbb104

fix stateless model support
1 parent f82ff06 commit 0dbb104

2 files changed: +105 -47 lines


optimum/intel/openvino/modeling_decoder.py (+81 -40)
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import warnings
 import logging
 import os
-import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import openvino
@@ -28,21 +28,10 @@
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.generation import GenerationMixin
-from transformers.generation.beam_search import BeamScorer, ConstrainedBeamSearchScorer
 from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
 from transformers.generation.logits_process import LogitsProcessorList
-from transformers.generation.stopping_criteria import (
-    EosTokenCriteria,
-    StoppingCriteriaList,
-    validate_stopping_criteria,
-)
-from transformers.generation.utils import (
-    GenerateBeamDecoderOnlyOutput,
-    GenerateBeamOutput,
-    GenerateOutput,
-    _split_model_inputs,
-    stack_model_outputs,
-)
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import CausalLMOutputWithPast

 from optimum.utils.normalized_config import NormalizedConfigManager
@@ -398,7 +387,11 @@ def prepare_inputs(
         inputs = {}
         if not self.stateful:
             if past_key_values is not None:
-                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
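Note on the new condition (it reappears in several hunks below): Python's `and` binds tighter than `or`, so the expression reads as "the model is not a multi-query-attention model, or it is Falcon with the new decoder architecture", that is, its cache is stored as per-layer (key, value) pairs. A minimal sketch of an equivalent standalone predicate (the helper name is illustrative, not part of the commit):

    # Hypothetical helper, not part of this commit; it spells out the inline condition above.
    # `and` binds tighter than `or`, so the inline form parses the same way.
    def uses_per_layer_kv_pairs(config) -> bool:
        is_new_falcon = config.model_type == "falcon" and getattr(config, "new_decoder_architecture", False)
        return config.model_type not in MULTI_QUERY_ATTN_MODELS or is_new_falcon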
@@ -491,9 +484,6 @@ def forward(
             position_ids=position_ids,
             **kwargs,
         )
-
-        print(inputs["input_ids"].shape)
-
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
@@ -509,7 +499,11 @@ def forward(
             if self.use_cache:
                 # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
                 past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
                     # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                     past_key_values = tuple(
                         past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
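For stateless models the inference request returns the cache as one flat tuple of tensors; the slicing above regroups it into per-layer tuples of `num_pkv` entries. A minimal sketch of that regrouping with plain tuples (values are illustrative):

    # Sketch of the regrouping above, assuming num_pkv == 2 (one key and one value per layer).
    flat = ("k0", "v0", "k1", "v1", "k2", "v2")
    num_pkv = 2
    per_layer = tuple(flat[i : i + num_pkv] for i in range(0, len(flat), num_pkv))
    # per_layer == (("k0", "v0"), ("k1", "v1"), ("k2", "v2"))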
@@ -561,21 +555,33 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         return model_inputs

     def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
         if indicies.shape[0] != 1:
             logits = logits[indicies]
             if past_key_values and not self.stateful:
-                past_key_values = tuple(
-                    tuple(
-                        past_state[indicies]
-                        if not self.config.model_type == "chatglm"
-                        else past_state[:, indicies, ...]
-                        for past_state in layer_past
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
+                    past_key_values = tuple(
+                        tuple(
+                            past_state[indicies]
+                            if not self.config.model_type == "chatglm"
+                            else past_state[:, indicies, ...]
+                            for past_state in layer_past
+                        )
+                        for layer_past in past_key_values
                     )
-                    for layer_past in past_key_values
-                )
-            if self.stateful:
-                self.next_beam_idx = self.next_beam_idx[indicies]
-                self._second_iter_beam_search = True
+                else:
+                    past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+                self._second_iter_beam_search = True
         return logits, past_key_values

     def _deduplicate_inputs(self, model_inputs: Dict):
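`_expand_outputs_for_generation` maps deduplicated outputs back onto the full beam batch through fancy indexing with `indicies`. The fix adds two things for this path: multi-query-attention caches are flat tuples of tensors and are now indexed directly, and `next_beam_idx` falls back to an identity mapping when it has not been initialized yet. A small NumPy sketch of the expansion step (values are illustrative):

    import numpy as np

    # Two unique prompts expanded back to a beam batch of four by repeating rows.
    logits = np.array([[0.1, 0.9], [0.8, 0.2]])
    indicies = np.array([0, 0, 1, 1])        # each beam entry points at its unique prompt
    expanded_logits = logits[indicies]       # shape (4, 2)
    # First-step fallback for next_beam_idx, as in the code above:
    next_beam_idx = np.arange(logits.shape[0], dtype=int)[indicies]   # array([0, 0, 1, 1])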
@@ -591,12 +597,19 @@ def _deduplicate_inputs(self, model_inputs: Dict):
                 else:
                     shape = input_tensor.shape
                     dtype = input_tensor.element_type
-                    shape[0 if not self.config.model_type == "chatglm" else 1] = indicies.shape[0]
+                    upd_batch_size = indicies.shape[0]
+                    if self.config.model_type == "bloom":
+                        upd_batch_size *= self.config.num_attention_heads
+                    shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
                     upd_model_inputs[input_name] = Tensor(dtype, shape)
-            print(f"{input_name}: {upd_model_inputs[input_name].shape}")
         upd_model_inputs["input_ids"] = unique_input_ids
         if "beam_idx" in model_inputs:
-            beam_idx = np.arange(unique_input_ids.shape[0], dtype=int)
+            beam_range = (
+                unique_input_ids.shape[0]
+                if self.config.model_type != "bloom"
+                else unique_input_ids.shape[0] * self.config.num_attention_heads
+            )
+            beam_idx = np.arange(beam_range, dtype=int)
             upd_model_inputs["beam_idx"] = beam_idx
         return upd_model_inputs, reverse_indicies
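Bloom's OpenVINO cache keeps the batch and attention-head dimensions fused, so after deduplication both the first dimension of each cache tensor and the `beam_idx` range have to be scaled by `num_attention_heads`. A sketch of the arithmetic (numbers are illustrative):

    import numpy as np

    # Assumed bloom-style fused layout: cache tensors carry (batch * num_heads) in their first dim.
    num_attention_heads = 8
    unique_batch = 2                                       # rows left after np.unique on input_ids
    upd_batch_size = unique_batch * num_attention_heads    # 16, first dim of the reshaped cache tensors
    beam_idx = np.arange(upd_batch_size, dtype=int)        # one entry per (sequence, head) pair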

@@ -646,7 +659,9 @@ def _get_past_length(self, past_key_values=None):
             return 0
         if self.stateful:
             return self._past_length
-        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
+            self.config.model_type == "falcon" and self.config.new_decoder_architecture
+        ):
             return past_key_values[0].shape[-2]
         seq_length_dim = -2
         if self.config.model_type == "chatglm":
@@ -677,9 +692,14 @@ def _reorder_cache(
             self._second_iter_beam_search = False
             return past_key_values
         else:
-            return tuple(
-                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
-            )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
+            ):
+                return tuple(
+                    tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
+                    for layer_past in past_key_values
+                )
+            return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)

     def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -800,11 +820,12 @@ def _reorder_cache(
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
         if self.stateful:
-            beam_idx = np.array(beam_idx)
             batch_size = beam_idx.shape[0]
+            beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
             indices = np.array(range(batch_size * self.config.num_attention_heads))
             indices = indices.reshape([batch_size, self.config.num_attention_heads])
             self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
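In the stateful Bloom path, `next_beam_idx` has to address the fused `(batch * num_heads)` dimension, so the selected beams are expanded head-wise; the fix also keeps the index produced by `_expand_outputs_for_generation` on the step right after deduplication instead of overwriting it. A worked example of the head-wise expansion (small numbers for readability):

    import numpy as np

    # Worked example of the expansion above: 2 beams, 3 attention heads.
    batch_size, num_attention_heads = 2, 3
    beam_idx = np.array([1, 1])                                    # both beams continue from sequence 1
    indices = np.array(range(batch_size * num_attention_heads))    # [0, 1, 2, 3, 4, 5]
    indices = indices.reshape([batch_size, num_attention_heads])   # [[0, 1, 2], [3, 4, 5]]
    next_beam_idx = np.take(indices, beam_idx, 0).flatten()        # [3, 4, 5, 3, 4, 5]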
@@ -854,14 +875,34 @@ def _convert_to_standard_cache(
             for layer_past in past_key_value
         )

+    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indicies.shape[0] != 1:
+            logits = logits[indicies]
+            if past_key_values and not self.stateful:
+                pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
+                pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard)
+                past_key_values = self._convert_to_bloom_cache(pkv)
+
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+                self._second_iter_beam_search = True
+        return logits, past_key_values
+

 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
         if self.stateful:
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
tests/openvino/test_modeling.py (+24 -7)
@@ -778,14 +778,31 @@ def test_default_filling_attention_mask_and_position_ids(self):
         del model_with_cache
         gc.collect()

-    def test_beam_search(self):
-        model_id = MODEL_NAMES["llama"]
-        ov_model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
-        ov_model_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
+    def test_beam_search(self, model_arch):
+        model_kwargs = {}
+        model_id = MODEL_NAMES[model_arch]
+        if model_arch in self.REMOTE_CODE_MODELS:
+            model_kwargs = {
+                "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
+                "trust_remote_code": True,
+            }
+        # Qwen tokenizer does not support padding, chatgm testing model produces nan that incompatible with beam search
+        if model_arch in ["qwen", "chatglm"]:
+            return

-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
+        ov_model_stateful = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=True, **model_kwargs
+        )
+        ov_model_stateless = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=False, **model_kwargs
+        )
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
         tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
         ov_model_stateful.generation_config.eos_token_id = None
         ov_model_stateless.generation_config.eos_token_id = None
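The test now exercises beam search for every entry in `SUPPORTED_ARCHITECTURES` and compares three models: the stateful OpenVINO export, the stateless one, and the reference transformers model. Since it is decorated with `@slow` and `@pytest.mark.run_slow`, it presumably only runs when the slow suite is enabled (for the transformers-style `@slow` decorator that usually means setting `RUN_SLOW=1`). A hedged sketch of the comparison the test is built around, continuing the variables set up above (the generation settings are illustrative, not necessarily the test's exact ones):

    # Illustrative check only; `torch` and the variables above come from the test module.
    gen_kwargs = {"max_new_tokens": 10, "num_beams": 4, "do_sample": False}
    reference = transformers_model.generate(**tokens, **gen_kwargs)
    stateful_out = ov_model_stateful.generate(**tokens, **gen_kwargs)
    stateless_out = ov_model_stateless.generate(**tokens, **gen_kwargs)
    assert torch.equal(reference, stateful_out)
    assert torch.equal(reference, stateless_out)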
