Add an OpenVINO export configuration for Falcon.

optimum/exporters/openvino/model_configs.py
  * Import FalconOnnxConfig and FalconDummyPastKeyValuesGenerator.
  * Add OVFalconDummyPastKeyValuesGenerator, a subclass of
    FalconDummyPastKeyValuesGenerator. Its __init__ forwards every argument
    to the parent initializer unchanged and then overrides two attributes:
      - num_kv_heads: normalized_config.num_attention_heads when
        normalized_config.new_decoder_architecture is truthy; otherwise 1
        when normalized_config.multi_query is truthy, else
        normalized_config.num_kv_heads.
      - head_dim: self.hidden_size // self.num_attention_heads.
    NOTE(review): the patch does not say why the new decoder architecture
    uses one K/V head per attention head; presumably this mirrors the cache
    layout of the exported model. Confirm against the Falcon modeling code.
  * Register FalconOpenVINOConfig (subclass of FalconOnnxConfig) for model
    type "falcon", library "transformers", for the tasks feature-extraction,
    feature-extraction-with-past, question-answering, text-generation,
    text-generation-with-past and token-classification. The class puts
    OVFalconDummyPastKeyValuesGenerator ahead of
    TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES and also uses it as
    DUMMY_PKV_GENERATOR_CLASS.

tests/openvino/test_modeling.py
  * Add "falcon-40b" to a tuple of names inside
    OVModelForCausalLMIntegrationTest (the tuple's name is outside the hunk).

tests/openvino/utils_tests.py
  * Map "falcon-40b" to "katuni4ka/tiny-random-falcon-40b".

NOTE(review): line breaks and indentation were rebuilt from a flattened copy
of this patch using 4-space Python indentation; blank lines were placed so
each hunk matches its declared line counts. Verify with `git apply --check`.

diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 90297c8fb3..66dfb5ae8f 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -19,13 +19,14 @@
 from transformers.utils import is_tf_available
 
 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
-from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig
+from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (
     DummyInputGenerator,
     DummyPastKeyValuesGenerator,
     DummyTextInputGenerator,
+    FalconDummyPastKeyValuesGenerator,
     MistralDummyPastKeyValuesGenerator,
 )
 from optimum.utils.normalized_config import NormalizedTextConfig
@@ -437,3 +438,50 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+            **kwargs,
+        )
+        if normalized_config.new_decoder_architecture:
+            self.num_kv_heads = normalized_config.num_attention_heads
+        else:
+            self.num_kv_heads = normalized_config.num_kv_heads if not normalized_config.multi_query else 1
+
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+
+@register_in_tasks_manager(
+    "falcon",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "question-answering",
+        "text-generation",
+        "text-generation-with-past",
+        "token-classification",
+    ],
+    library_name="transformers",
+)
+class FalconOpenVINOConfig(FalconOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        OVFalconDummyPastKeyValuesGenerator,
+    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index f84cac8161..b03f36d458 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -538,6 +538,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "internlm2",
         "orion",
         "falcon",
+        "falcon-40b",
     )
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion")
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index ca56f6d552..a6f5e664ec 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -44,6 +44,7 @@
     "electra": "hf-internal-testing/tiny-random-electra",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "falcon": "fxmarty/really-tiny-falcon-testing",
+    "falcon-40b": "katuni4ka/tiny-random-falcon-40b",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",