@@ -19,7 +19,8 @@
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
- from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, MixtralModelPatcher
+ from optimum.exporters.onnx.model_configs import GemmaOnnxConfig
+ from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, GemmaModelPatcher, MixtralModelPatcher
from optimum.exporters.tasks import TasksManager
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import (
@@ -65,23 +66,23 @@ def init_model_configs():
register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)


- @register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"])
+ @register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers")
class BaichaunOpenVINOConfig(TextDecoderOnnxConfig):
    DEFAULT_ONNX_OPSET = 13
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
    )


- @register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"])
+ @register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"], library_name="transformers")
class JaisOpenVINOConfig(TextDecoderOnnxConfig):
    DEFAULT_ONNX_OPSET = 13
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        num_layers="n_layer", num_attention_heads="n_head", hidden_size="n_embd"
    )


- @register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"])
+ @register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"], library_name="transformers")
class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14

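Reviewer note (not part of the diff): the recurring library_name="transformers" argument reflects newer optimum releases keying registered export configs by (model type, task, library) rather than by model type and task alone, so registrations without an explicit library may no longer resolve. A hedged sketch of how one of the registrations above would be looked up; the keyword names follow recent optimum versions and the exact signature may differ in older ones:

    from optimum.exporters.tasks import TasksManager

    # Resolve the exporter config registered for "baichuan" above
    # (signature details are an assumption; check the installed optimum version).
    config_constructor = TasksManager.get_exporter_config_constructor(
        exporter="openvino",
        task="text-generation-with-past",
        model_type="baichuan",
        library_name="transformers",
    )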
@@ -90,7 +91,7 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


- @register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"])
+ @register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14

@@ -99,7 +100,7 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


- @register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"])
+ @register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14

@@ -128,7 +129,7 @@ def __init__(
            random_sequence_length_range=random_sequence_length_range,
        )
        self.multi_query_group_num = normalized_config.multi_query_group_num
-         self.head_dim = self.hidden_size // self.num_attention_heads
+         self.head_dim = normalized_config.kv_channels

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        past_key_shape = (
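Reviewer note (not part of the diff): ChatGLM2 stores kv_channels explicitly in its config, and under multi-query attention the per-group key/value width need not equal hidden_size // num_attention_heads, so reading it from the config is the safer choice. A minimal sketch of the dummy KV-cache shapes this generator produces, using assumed, chatglm2-6b-like values (illustrative, not authoritative):

    # Assumed config values, roughly chatglm2-6b-like:
    num_layers = 28
    multi_query_group_num = 2   # key/value groups under multi-query attention
    kv_channels = 128           # per-head KV width, stored explicitly in the config

    batch_size, sequence_length = 2, 16

    # ChatGLM2 keeps its KV cache sequence-first: one (key, value) pair per layer,
    # each of shape (sequence_length, batch, groups, head_dim).
    past_shapes = [
        (
            (sequence_length, batch_size, multi_query_group_num, kv_channels),
            (sequence_length, batch_size, multi_query_group_num, kv_channels),
        )
        for _ in range(num_layers)
    ]
    print(past_shapes[0][0])  # (16, 2, 2, 128)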
@@ -152,7 +153,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
        ]


- @register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
+ @register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class ChatGLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers")
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
@@ -232,7 +233,7 @@ def patch_model_for_export(
        return ChatGLMModelPatcher(self, model, model_kwargs=model_kwargs)


- @register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"])
+ @register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    # Needed because of the patching of torch.triu in AttentionMaskConverter, which exists in transformers>=4.35
    MIN_TRANSFORMERS_VERSION = version.parse("4.34.99")
@@ -249,3 +250,21 @@ def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return MixtralModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+ @register_in_tasks_manager(
+     "gemma",
+     *[
+         "feature-extraction",
+         "feature-extraction-with-past",
+         "text-generation",
+         "text-generation-with-past",
+         "text-classification",
+     ],
+     library_name="transformers",
+ )
+ class GemmaOpenVINOConfig(GemmaOnnxConfig):
+     def patch_model_for_export(
+         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+     ) -> "ModelPatcher":
+         return GemmaModelPatcher(self, model, model_kwargs=model_kwargs)
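For context (a sketch, not part of the diff): once "gemma" is registered for the OpenVINO exporter, the usual optimum-intel entry points convert and run Gemma checkpoints through GemmaOpenVINOConfig, applying GemmaModelPatcher during export. Assuming optimum-intel is installed with its OpenVINO extras; the checkpoint id is illustrative:

    from optimum.intel import OVModelForCausalLM
    from transformers import AutoTokenizer

    model_id = "google/gemma-2b"  # illustrative; any Gemma checkpoint should route the same way

    # export=True converts the PyTorch weights to OpenVINO IR on the fly,
    # using the config registered for the "gemma" model type above.
    model = OVModelForCausalLM.from_pretrained(model_id, export=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = tokenizer("OpenVINO is", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))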