Commit 1b29066

gemma

1 parent 02889ac · commit 1b29066

4 files changed: +85 -19 lines

This commit adds Gemma support to the OpenVINO exporter: it registers a GemmaOpenVINOConfig, adds a GemmaModelPatcher that pre-computes the rotary-embedding inverse frequencies for tracing, moves the ChatGLM patcher onto scaled_dot_product_attention, and updates the causal-LM tests (llama and gemma export as stateful models on OpenVINO >= 2024.1).

optimum/exporters/openvino/model_configs.py (+28 -9)

```diff
@@ -19,7 +19,8 @@
 from transformers.utils import is_tf_available
 
 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
-from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, MixtralModelPatcher
+from optimum.exporters.onnx.model_configs import GemmaOnnxConfig
+from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, GemmaModelPatcher, MixtralModelPatcher
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (
@@ -65,23 +66,23 @@ def init_model_configs():
 register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)
 
 
-@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class BaichaunOpenVINOConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
         num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
     )
 
 
-@register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class JaisOpenVINOConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
         num_layers="n_layer", num_attention_heads="n_head", hidden_size="n_embd"
     )
 
 
-@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
 
@@ -90,7 +91,7 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
-@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
 
@@ -99,7 +100,7 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
-@register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
 
@@ -128,7 +129,7 @@ def __init__(
             random_sequence_length_range=random_sequence_length_range,
         )
         self.multi_query_group_num = normalized_config.multi_query_group_num
-        self.head_dim = self.hidden_size // self.num_attention_heads
+        self.head_dim = normalized_config.kv_channels
 
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         past_key_shape = (
@@ -152,7 +153,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
         ]
 
 
-@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class ChatGLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers")
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
@@ -232,7 +233,7 @@ def patch_model_for_export(
         return ChatGLMModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
-@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
     MIN_TRANSFORMERS_VERSION = version.parse("4.34.99")
@@ -249,3 +250,21 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return MixtralModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "gemma",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GemmaOpenVINOConfig(GemmaOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GemmaModelPatcher(self, model, model_kwargs=model_kwargs)
```
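For reference, the new registration is exercised automatically when loading a Gemma checkpoint through optimum-intel. A minimal sketch, reusing the tiny test checkpoint added to `utils_tests.py` below:

```python
from optimum.intel import OVModelForCausalLM

# export=True runs the OpenVINO export path, which now resolves the "gemma"
# model type through the GemmaOpenVINOConfig registered above and applies
# GemmaModelPatcher while the model is traced.
ov_model = OVModelForCausalLM.from_pretrained(
    "fxmarty/tiny-random-GemmaForCausalLM", export=True
)
```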

optimum/exporters/openvino/model_patcher.py (+49 -1)

```diff
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -214,6 +214,34 @@ def _chatglm_transformer_forward(
     )
 
 
+@torch.jit.script_if_tracing
+def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor):
+    mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype)
+    if query_layer.shape[2] == key_layer.shape[2]:
+        tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1)
+        mask.masked_fill_(tmp_mask, float("-inf"))
+
+    context_layer = torch.nn.functional.scaled_dot_product_attention(
+        query_layer, key_layer, value_layer, attn_mask=mask
+    )
+    return context_layer
+
+
+def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
+    query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+    if attention_mask is None:
+        context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer)
+    else:
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer, key_layer, value_layer, attention_mask
+        )
+    context_layer = context_layer.permute(2, 0, 1, 3)
+    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+    context_layer = context_layer.reshape(*new_context_layer_shape)
+
+    return context_layer
+
+
 class ChatGLMModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
```
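In `_chatglm2_get_context_layer`, the additive mask is all zeros during single-token decoding and becomes a causal `-inf` upper triangle when query and key lengths match (prefill). A standalone sketch of the same construction, with a hypothetical 4-token prefill:

```python
import torch

# Hypothetical sizes for illustration: 4 query positions attending to 4 keys,
# as during prefill (query length == key length).
q_len, k_len = 4, 4
mask = torch.zeros((q_len, k_len))
if q_len == k_len:
    # The boolean upper triangle above the diagonal marks future positions...
    tmp_mask = torch.ones((q_len, k_len), dtype=torch.bool).triu(diagonal=1)
    # ...which get -inf so scaled_dot_product_attention zeroes them after softmax.
    mask.masked_fill_(tmp_mask, float("-inf"))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```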
```diff
@@ -228,7 +256,27 @@ def __init__(
     def __enter__(self):
         super().__enter__()
         self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)
+        for block in self._model.transformer.encoder.layers:
+            block.self_attention.core_attention._orig_forward = block.self_attention.core_attention.forward
+            block.self_attention.core_attention.forward = types.MethodType(
+                _chatglm2_core_attention_forward, block.self_attention.core_attention
+            )
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.transformer.forward = self.original_chatglm_transformer_forward
+        for block in self._model.transformer.encoder.layers:
+            block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward
+
+
+class GemmaModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        # init inv_freq for torchscript tracing
+        for layer in self._model.model.layers:
+            if layer.self_attn.rotary_emb.inv_freq is None:
+                rotary_emb = layer.self_attn.rotary_emb
+                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
+                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
+                )
```
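`GemmaModelPatcher` is needed because Gemma's rotary embedding initializes `inv_freq` lazily (it stays `None` until the first forward), which breaks torchscript tracing; the patcher fills it in up front with the standard RoPE inverse-frequency schedule. A standalone sketch with assumed example values (head dimension 256 and base 10000, Gemma's defaults):

```python
import torch

# Assumed example values: Gemma's default rotary dimension and base.
dim, base = 256, 10000

# One inverse frequency per pair of rotary channels: inv_freq[i] = base^(-2i/dim).
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq.shape)  # torch.Size([128])
print(inv_freq[:3])    # tensor([1.0000, 0.9306, 0.8660])
```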

tests/openvino/test_modeling.py (+7 -9)

```diff
@@ -53,7 +53,6 @@
 from transformers.onnx.utils import get_preprocessor
 from utils_tests import MODEL_NAMES
 
-from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
 from optimum.intel import (
     OVModelForAudioClassification,
     OVModelForAudioFrameClassification,
@@ -481,6 +480,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "chatglm",
         "codegen",
         # "data2vec-text", # TODO : enable when enabled in exporters
+        "gemma",
         "gpt2",
         "gpt_neo",
         "gpt_neox",
@@ -502,10 +502,13 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
-        not_stateful = ["gpt_bigcode", "llama"]
+        not_stateful = ["gpt_bigcode"]
         if is_openvino_version("<", "2024.0"):
             not_stateful.append("mixtral")
 
+        if is_openvino_version("<", "2024.1"):
+            not_stateful.extend(["llama", "gemma"])
+
         if "gptq" in model_arch:
             self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM")
 
@@ -528,11 +531,7 @@ def test_compare_to_transformers(self, model_arch):
         tokens = tokenizer(
             "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
         )
-        position_ids = None
-        if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
-            input_shape = tokens["input_ids"].shape
-            position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1])
-        ov_outputs = ov_model(**tokens, position_ids=position_ids)
+        ov_outputs = ov_model(**tokens)
 
         self.assertTrue("logits" in ov_outputs)
         self.assertIsInstance(ov_outputs.logits, torch.Tensor)
@@ -542,12 +541,11 @@ def test_compare_to_transformers(self, model_arch):
         self.assertEqual(ov_model.stateful, is_stateful)
         if is_stateful:
             self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
-
         with torch.no_grad():
             transformers_outputs = transformers_model(**tokens)
 
         # Compare tensor outputs
-        self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
+        self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=1e-4))
         del transformers_model
         del ov_model
         gc.collect()
```
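The `not_stateful` change reflects that llama and gemma exports are stateful starting with OpenVINO 2024.1, i.e. the KV-cache is kept inside the compiled model rather than passed around as `past_key_values` tensors. A quick check outside the test suite, assuming a recent enough OpenVINO and reusing the tiny Gemma checkpoint:

```python
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "fxmarty/tiny-random-GemmaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)

# position_ids no longer need to be built by hand; the forward pass fills them in.
outputs = ov_model(**tokenizer("This is a sample", return_tensors="pt"))

# For a stateful model the returned past_key_values is an empty placeholder,
# which is exactly what the test asserts.
print(ov_model.stateful, len(outputs.past_key_values[0]))  # expected: True 0
```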

tests/openvino/utils_tests.py (+1 -0)

```diff
@@ -39,6 +39,7 @@
     "convnext": "hf-internal-testing/tiny-random-convnext",
     "distilbert": "hf-internal-testing/tiny-random-distilbert",
     "electra": "hf-internal-testing/tiny-random-electra",
+    "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
```
