
Commit 7d327ed

Merge branch 'main' into support_phi3_export

2 parents 602b9c3 + b97b601

File tree: 6 files changed, +62 -3 lines changed


optimum/exporters/openvino/model_configs.py (+49 -1)
@@ -19,13 +19,14 @@
 from transformers.utils import is_tf_available

 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
-from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig, PhiOnnxConfig
+from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig, PhiOnnxConfig
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (
     DummyInputGenerator,
     DummyPastKeyValuesGenerator,
     DummyTextInputGenerator,
+    FalconDummyPastKeyValuesGenerator,
     MistralDummyPastKeyValuesGenerator,
 )
 from optimum.utils.normalized_config import NormalizedTextConfig

@@ -456,3 +457,50 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return Phi3ModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+            **kwargs,
+        )
+        if normalized_config.new_decoder_architecture:
+            self.num_kv_heads = normalized_config.num_attention_heads
+        else:
+            self.num_kv_heads = normalized_config.num_kv_heads if not normalized_config.multi_query else 1
+
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+
+@register_in_tasks_manager(
+    "falcon",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "question-answering",
+        "text-generation",
+        "text-generation-with-past",
+        "token-classification",
+    ],
+    library_name="transformers",
+)
+class FalconOpenVINOConfig(FalconOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        OVFalconDummyPastKeyValuesGenerator,
+    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator
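The custom generator exists because Falcon checkpoints describe their KV-cache layout with two config flags: new_decoder_architecture (falcon-40b style, one KV head per attention head) and multi_query (falcon-7b style, a single shared KV head). Below is a minimal standalone sketch of that selection rule, mirroring the branch added above; the function name and values are illustrative only and not part of the commit.

def falcon_num_kv_heads(new_decoder_architecture, multi_query, num_attention_heads, num_kv_heads):
    # falcon-40b style: every attention head gets its own KV head
    if new_decoder_architecture:
        return num_attention_heads
    # falcon-7b style: multi-query attention shares a single KV head
    return 1 if multi_query else num_kv_heads

assert falcon_num_kv_heads(True, False, 8, 1) == 8   # 40b-like config
assert falcon_num_kv_heads(False, True, 8, 1) == 1   # 7b-like (multi-query) config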

optimum/intel/generation/modeling.py (+9 -0)
@@ -180,13 +180,22 @@ def _reorder_cache(
         """
         if self.config.model_type == "bloom":
             return self._reorder_cache_bloom(past_key_values, beam_idx)
+        elif self.config.model_type == "gpt_bigcode":
+            return self._reorder_cache_gpt_bigcode(past_key_values, beam_idx)

         # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past_key_values
         )

+    # Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
+    @staticmethod
+    def _reorder_cache_gpt_bigcode(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
+
     # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache_bloom(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
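The dedicated helper is needed because gpt_bigcode uses multi-query attention and caches key and value fused into a single tensor per layer, rather than the usual (key, value) pair, so beam reordering indexes each layer tensor directly along the batch dimension. A minimal sketch with assumed, illustrative shapes:

import torch

layer_past = torch.randn(4, 10, 128)   # (num_beams, seq_len, fused KV width); shapes illustrative
beam_idx = torch.tensor([2, 2, 0, 1])  # beams selected at this generation step
reordered = layer_past.index_select(0, beam_idx.to(layer_past.device))
assert reordered.shape == layer_past.shape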

tests/generation/test_modeling.py (+1 -1)
@@ -58,7 +58,7 @@ class ModelingIntegrationTest(unittest.TestCase):
         "mistral",
         "llama",
         "llama2",
-        # "gpt_bigcode",
+        "gpt_bigcode",
     )

     GENERATION_LENGTH = 100
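Re-enabling "gpt_bigcode" here matters because the _reorder_cache_gpt_bigcode branch added above is exercised whenever generation runs with beam search. A hedged usage sketch; TSModelForCausalLM and the export=True argument are assumed from this module's API rather than shown in the diff:

from transformers import AutoTokenizer
from optimum.intel.generation.modeling import TSModelForCausalLM

model_id = "hf-internal-testing/tiny-random-GPTBigCodeModel"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = TSModelForCausalLM.from_pretrained(model_id, export=True)
inputs = tokenizer("def hello():", return_tensors="pt")
out = model.generate(**inputs, num_beams=4, max_new_tokens=8)  # beam search triggers cache reordering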

tests/ipex/test_inference.py (+1 -1)
@@ -65,7 +65,7 @@ class IPEXIntegrationTest(unittest.TestCase):
         "gptj",
         "gpt2",
         "gpt_neo",
-        # "gpt_bigcode",
+        "gpt_bigcode",
         "llama",
         "llama2",
         "opt",

tests/openvino/test_modeling.py (+1 -0)
@@ -539,6 +539,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "internlm2",
         "orion",
         "falcon",
+        "falcon-40b",
     )
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion", "phi3")
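The new "falcon-40b" entry gives the new_decoder_architecture=True export path its own coverage; the existing tiny "falcon" checkpoint presumably exercises the 7b-style layout. A hedged sketch of what the test covers end to end; OVModelForCausalLM with export=True is the standard optimum-intel API, and the model id comes from tests/openvino/utils_tests.py below:

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "katuni4ka/tiny-random-falcon-40b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # export goes through FalconOpenVINOConfig
inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=5)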

tests/openvino/utils_tests.py (+1 -0)
@@ -44,6 +44,7 @@
     "electra": "hf-internal-testing/tiny-random-electra",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "falcon": "fxmarty/really-tiny-falcon-testing",
+    "falcon-40b": "katuni4ka/tiny-random-falcon-40b",
    "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
