Skip to content

Commit 3fb7712

Browse files
authored
Fix input generator for falcon40b (#685)
* fix input generator for falcon40b * add test model
1 parent 4bee0c7 commit 3fb7712

File tree

3 files changed

+51
-1
lines changed

3 files changed

+51
-1
lines changed

optimum/exporters/openvino/model_configs.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919
from transformers.utils import is_tf_available
2020

2121
from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
22-
from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig
22+
from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig
2323
from optimum.exporters.tasks import TasksManager
2424
from optimum.utils import DEFAULT_DUMMY_SHAPES
2525
from optimum.utils.input_generators import (
2626
DummyInputGenerator,
2727
DummyPastKeyValuesGenerator,
2828
DummyTextInputGenerator,
29+
FalconDummyPastKeyValuesGenerator,
2930
MistralDummyPastKeyValuesGenerator,
3031
)
3132
from optimum.utils.normalized_config import NormalizedTextConfig
@@ -437,3 +438,50 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
437438
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
438439
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
439440
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
441+
442+
443+
class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator):
444+
def __init__(
445+
self,
446+
task: str,
447+
normalized_config: NormalizedTextConfig,
448+
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
449+
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
450+
random_batch_size_range: Optional[Tuple[int, int]] = None,
451+
random_sequence_length_range: Optional[Tuple[int, int]] = None,
452+
**kwargs,
453+
):
454+
super().__init__(
455+
task=task,
456+
normalized_config=normalized_config,
457+
batch_size=batch_size,
458+
sequence_length=sequence_length,
459+
random_batch_size_range=random_batch_size_range,
460+
random_sequence_length_range=random_sequence_length_range,
461+
**kwargs,
462+
)
463+
if normalized_config.new_decoder_architecture:
464+
self.num_kv_heads = normalized_config.num_attention_heads
465+
else:
466+
self.num_kv_heads = normalized_config.num_kv_heads if not normalized_config.multi_query else 1
467+
468+
self.head_dim = self.hidden_size // self.num_attention_heads
469+
470+
471+
@register_in_tasks_manager(
472+
"falcon",
473+
*[
474+
"feature-extraction",
475+
"feature-extraction-with-past",
476+
"question-answering",
477+
"text-generation",
478+
"text-generation-with-past",
479+
"token-classification",
480+
],
481+
library_name="transformers",
482+
)
483+
class FalconOpenVINOConfig(FalconOnnxConfig):
484+
DUMMY_INPUT_GENERATOR_CLASSES = (
485+
OVFalconDummyPastKeyValuesGenerator,
486+
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
487+
DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator

tests/openvino/test_modeling.py

+1
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
538538
"internlm2",
539539
"orion",
540540
"falcon",
541+
"falcon-40b",
541542
)
542543
GENERATION_LENGTH = 100
543544
REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion")

tests/openvino/utils_tests.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"electra": "hf-internal-testing/tiny-random-electra",
4545
"gemma": "fxmarty/tiny-random-GemmaForCausalLM",
4646
"falcon": "fxmarty/really-tiny-falcon-testing",
47+
"falcon-40b": "katuni4ka/tiny-random-falcon-40b",
4748
"flaubert": "hf-internal-testing/tiny-random-flaubert",
4849
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
4950
"gpt2": "hf-internal-testing/tiny-random-gpt2",

0 commit comments

Comments
 (0)