|
19 | 19 | from transformers.utils import is_tf_available
|
20 | 20 |
|
21 | 21 | from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
|
22 |
| -from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig |
| 22 | +from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig |
23 | 23 | from optimum.exporters.tasks import TasksManager
|
24 | 24 | from optimum.utils import DEFAULT_DUMMY_SHAPES
|
25 | 25 | from optimum.utils.input_generators import (
|
26 | 26 | DummyInputGenerator,
|
27 | 27 | DummyPastKeyValuesGenerator,
|
28 | 28 | DummyTextInputGenerator,
|
| 29 | + FalconDummyPastKeyValuesGenerator, |
29 | 30 | MistralDummyPastKeyValuesGenerator,
|
30 | 31 | )
|
31 | 32 | from optimum.utils.normalized_config import NormalizedTextConfig
|
@@ -437,3 +438,50 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
|
437 | 438 | DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
|
438 | 439 | DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
|
439 | 440 | NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
|
| 441 | + |
| 442 | + |
| 443 | +class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator): |
| 444 | + def __init__( |
| 445 | + self, |
| 446 | + task: str, |
| 447 | + normalized_config: NormalizedTextConfig, |
| 448 | + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], |
| 449 | + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], |
| 450 | + random_batch_size_range: Optional[Tuple[int, int]] = None, |
| 451 | + random_sequence_length_range: Optional[Tuple[int, int]] = None, |
| 452 | + **kwargs, |
| 453 | + ): |
| 454 | + super().__init__( |
| 455 | + task=task, |
| 456 | + normalized_config=normalized_config, |
| 457 | + batch_size=batch_size, |
| 458 | + sequence_length=sequence_length, |
| 459 | + random_batch_size_range=random_batch_size_range, |
| 460 | + random_sequence_length_range=random_sequence_length_range, |
| 461 | + **kwargs, |
| 462 | + ) |
| 463 | + if normalized_config.new_decoder_architecture: |
| 464 | + self.num_kv_heads = normalized_config.num_attention_heads |
| 465 | + else: |
| 466 | + self.num_kv_heads = normalized_config.num_kv_heads if not normalized_config.multi_query else 1 |
| 467 | + |
| 468 | + self.head_dim = self.hidden_size // self.num_attention_heads |
| 469 | + |
| 470 | + |
| 471 | +@register_in_tasks_manager( |
| 472 | + "falcon", |
| 473 | + *[ |
| 474 | + "feature-extraction", |
| 475 | + "feature-extraction-with-past", |
| 476 | + "question-answering", |
| 477 | + "text-generation", |
| 478 | + "text-generation-with-past", |
| 479 | + "token-classification", |
| 480 | + ], |
| 481 | + library_name="transformers", |
| 482 | +) |
| 483 | +class FalconOpenVINOConfig(FalconOnnxConfig): |
| 484 | + DUMMY_INPUT_GENERATOR_CLASSES = ( |
| 485 | + OVFalconDummyPastKeyValuesGenerator, |
| 486 | + ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES |
| 487 | + DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator |
0 commit comments