Skip to content

Commit 16c01d0

Browse files
committed
fix input generator for falcon40b
1 parent a0dc06c commit 16c01d0

File tree

1 file changed

+51
-1
lines changed

1 file changed

+51
-1
lines changed

optimum/exporters/openvino/model_configs.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
from transformers.utils import is_tf_available
2020

2121
from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
22-
from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig
22+
from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig, FalconOnnxConfig
2323
from optimum.exporters.tasks import TasksManager
2424
from optimum.utils import DEFAULT_DUMMY_SHAPES
2525
from optimum.utils.input_generators import (
2626
DummyInputGenerator,
2727
DummyPastKeyValuesGenerator,
2828
DummyTextInputGenerator,
2929
MistralDummyPastKeyValuesGenerator,
30+
FalconDummyPastKeyValuesGenerator,
3031
)
3132
from optimum.utils.normalized_config import NormalizedTextConfig
3233

@@ -437,3 +438,52 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
437438
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
438439
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
439440
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
441+
442+
443+
class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator):
444+
def __init__(
445+
self,
446+
task: str,
447+
normalized_config: NormalizedTextConfig,
448+
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
449+
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
450+
random_batch_size_range: Optional[Tuple[int, int]] = None,
451+
random_sequence_length_range: Optional[Tuple[int, int]] = None,
452+
**kwargs,
453+
):
454+
super().__init__(
455+
task=task,
456+
normalized_config=normalized_config,
457+
batch_size=batch_size,
458+
sequence_length=sequence_length,
459+
random_batch_size_range=random_batch_size_range,
460+
random_sequence_length_range=random_sequence_length_range,
461+
**kwargs,
462+
)
463+
if normalized_config.new_decoder_architecture and normalized_config.multi_query:
464+
self.num_kv_heads = normalized_config.num_attention_heads
465+
elif normalized_config.new_decoder_architecture and not normalized_config.multi_query:
466+
self.num_kv_heads = normalized_config.num_kv_heads
467+
else:
468+
self.num_kv_heads = 1
469+
470+
self.head_dim = self.hidden_size // self.num_attention_heads
471+
472+
473+
@register_in_tasks_manager(
474+
"falcon",
475+
*[
476+
"feature-extraction",
477+
"feature-extraction-with-past",
478+
"question-answering",
479+
"text-generation",
480+
"text-generation-with-past",
481+
"token-classification",
482+
],
483+
library_name="transformers",
484+
)
485+
class FalconOpenVINOConfig(FalconOnnxConfig):
486+
DUMMY_INPUT_GENERATOR_CLASSES = (
487+
OVFalconDummyPastKeyValuesGenerator,
488+
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
489+
DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator

0 commit comments

Comments
 (0)