
Commit 828dce1

gemma

1 parent 02889ac

4 files changed: +87 -20 lines

optimum/exporters/openvino/model_configs.py (+28 -9)
```diff
@@ -19,7 +19,8 @@
 from transformers.utils import is_tf_available
 
 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
-from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, MixtralModelPatcher
+from optimum.exporters.onnx.model_configs import GemmaOnnxConfig
+from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, GemmaModelPatcher, MixtralModelPatcher
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (
@@ -65,23 +66,23 @@ def init_model_configs():
 register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)
 
 
-@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class BaichaunOpenVINOConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
         num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
     )
 
 
-@register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class JaisOpenVINOConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
         num_layers="n_layer", num_attention_heads="n_head", hidden_size="n_embd"
     )
 
 
-@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
 
@@ -90,7 +91,7 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
-@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
 
@@ -99,7 +100,7 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
-@register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
 
@@ -128,7 +129,7 @@ def __init__(
             random_sequence_length_range=random_sequence_length_range,
         )
         self.multi_query_group_num = normalized_config.multi_query_group_num
-        self.head_dim = self.hidden_size // self.num_attention_heads
+        self.head_dim = normalized_config.kv_channels
 
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         past_key_shape = (
@@ -152,7 +153,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
         ]
 
 
-@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class ChatGLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers")
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
@@ -232,7 +233,7 @@ def patch_model_for_export(
         return ChatGLMModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
-@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"])
+@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
     MIN_TRANSFORMERS_VERSION = version.parse("4.34.99")
@@ -249,3 +250,21 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return MixtralModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "gemma",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GemmaOpenVINOConfig(GemmaOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GemmaModelPatcher(self, model, model_kwargs=model_kwargs)
```
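With `gemma` registered in the `openvino` tasks registry, a Gemma checkpoint can be exported and run like any other supported decoder. A minimal sketch, assuming an `optimum-intel` build that includes this commit and using the tiny test checkpoint added further down in this diff:

```python
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly,
# resolving GemmaOpenVINOConfig through the "openvino" tasks registry.
model_id = "fxmarty/tiny-random-GemmaForCausalLM"
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

tokens = tokenizer("This is a sample", return_tensors="pt")
outputs = model.generate(**tokens, max_new_tokens=10)
print(tokenizer.decode(outputs[0]))
```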

optimum/exporters/openvino/model_patcher.py (+48 -1)
```diff
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -214,6 +214,33 @@ def _chatglm_transformer_forward(
     )
 
 
+def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor):
+    mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype)
+    if query_layer.shape[2] == key_layer.shape[2]:
+        tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1)
+        mask.masked_fill_(tmp_mask, float("-inf"))
+
+    context_layer = torch.nn.functional.scaled_dot_product_attention(
+        query_layer, key_layer, value_layer, attn_mask=mask
+    )
+    return context_layer
+
+
+def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
+    query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+    if attention_mask is None:
+        context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer)
+    else:
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer, key_layer, value_layer, attention_mask
+        )
+    context_layer = context_layer.permute(2, 0, 1, 3)
+    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+    context_layer = context_layer.reshape(*new_context_layer_shape)
+
+    return context_layer
+
+
 class ChatGLMModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
```
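These helpers route ChatGLM's core attention through `torch.nn.functional.scaled_dot_product_attention`. An explicit additive causal mask is built only during prefill (query length equals key length); at decode time, when a single new token attends to the whole KV cache, the all-zero mask leaves everything visible. A standalone sketch of the mask logic, with made-up shapes:

```python
import torch

# Hypothetical prefill shapes: batch=1, heads=2, q_len == kv_len == 4, head_dim=8
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Additive mask: 0 on and below the diagonal, -inf strictly above it,
# so each position can only attend to itself and earlier tokens.
mask = torch.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype)
if q.shape[2] == k.shape[2]:  # prefill: apply causal masking
    upper = torch.ones((q.shape[-2], k.shape[-2]), dtype=torch.bool).triu(diagonal=1)
    mask.masked_fill_(upper, float("-inf"))

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```

The patcher installs this forward on entry and restores the saved `_orig_forward` on exit, as the next hunk shows: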
```diff
@@ -228,7 +255,27 @@ def __init__(
     def __enter__(self):
         super().__enter__()
         self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)
+        for block in self._model.transformer.encoder.layers:
+            block.self_attention.core_attention._orig_forward = block.self_attention.core_attention.forward
+            block.self_attention.core_attention.forward = types.MethodType(
+                _chatglm2_core_attention_forward, block.self_attention.core_attention
+            )
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.transformer.forward = self.original_chatglm_transformer_forward
+        for block in self._model.transformer.encoder.layers:
+            block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward
+
+
+class GemmaModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        # init inv_freq for torchscript tracing
+        for layer in self._model.model.layers:
+            if layer.self_attn.rotary_emb.inv_freq is None:
+                rotary_emb = layer.self_attn.rotary_emb
+                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
+                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
+                )
```
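`GemmaModelPatcher` materializes the rotary embedding's `inv_freq` buffer when it is `None`, so that torchscript tracing sees a concrete tensor instead of a lazily computed value. The formula is the standard RoPE inverse-frequency schedule; a standalone sketch with hypothetical values:

```python
import torch

# Hypothetical rotary parameters; the patcher reads these from
# rotary_emb.base and rotary_emb.dim on each attention layer.
base, dim = 10000.0, 256

# inv_freq[i] = base^(-2i / dim): one frequency per pair of rotated channels.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq.shape)  # torch.Size([128]) == dim // 2
```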

tests/openvino/test_modeling.py (+9 -10)
```diff
@@ -53,7 +53,6 @@
 from transformers.onnx.utils import get_preprocessor
 from utils_tests import MODEL_NAMES
 
-from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
 from optimum.intel import (
     OVModelForAudioClassification,
     OVModelForAudioFrameClassification,
@@ -481,12 +480,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "chatglm",
         "codegen",
         # "data2vec-text", # TODO : enable when enabled in exporters
+        "gemma",
         "gpt2",
         "gpt_neo",
         "gpt_neox",
         "llama",
         # "llama_gptq",
         "marian",
+        "minicpm",
         "mistral",
         "mixtral",
         "mpt",
@@ -497,15 +498,18 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
     )
     GENERATION_LENGTH = 100
     IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")
-    REMOTE_CODE_MODELS = ("chatglm",)
+    REMOTE_CODE_MODELS = ("chatglm", "minicpm")
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
-        not_stateful = ["gpt_bigcode", "llama"]
+        not_stateful = ["gpt_bigcode"]
         if is_openvino_version("<", "2024.0"):
             not_stateful.append("mixtral")
 
+        if is_openvino_version("<", "2024.1"):
+            not_stateful.extend(["llama", "gemma"])
+
         if "gptq" in model_arch:
             self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM")
 
@@ -528,11 +532,7 @@ def test_compare_to_transformers(self, model_arch):
         tokens = tokenizer(
             "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
         )
-        position_ids = None
-        if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
-            input_shape = tokens["input_ids"].shape
-            position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1])
-        ov_outputs = ov_model(**tokens, position_ids=position_ids)
+        ov_outputs = ov_model(**tokens)
 
         self.assertTrue("logits" in ov_outputs)
         self.assertIsInstance(ov_outputs.logits, torch.Tensor)
@@ -542,12 +542,11 @@ def test_compare_to_transformers(self, model_arch):
         self.assertEqual(ov_model.stateful, is_stateful)
         if is_stateful:
             self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
-
         with torch.no_grad():
             transformers_outputs = transformers_model(**tokens)
 
         # Compare tensor outputs
-        self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=1e-4))
+        self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=1e-4))
         del transformers_model
         del ov_model
         gc.collect()
```
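The logits comparison now passes `equal_nan=True`: by default `torch.allclose` treats NaN as unequal to everything, including another NaN, so a NaN produced at the same position by both the OpenVINO and PyTorch models would fail the test even though the two outputs agree. A quick illustration:

```python
import torch

a = torch.tensor([float("nan"), 1.0])
b = torch.tensor([float("nan"), 1.0])

print(torch.allclose(a, b))                  # False: NaN != NaN by default
print(torch.allclose(a, b, equal_nan=True))  # True: matching NaNs are accepted
```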

tests/openvino/utils_tests.py (+2 -0)
```diff
@@ -39,6 +39,7 @@
     "convnext": "hf-internal-testing/tiny-random-convnext",
     "distilbert": "hf-internal-testing/tiny-random-distilbert",
     "electra": "hf-internal-testing/tiny-random-electra",
+    "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
@@ -56,6 +57,7 @@
     "opt125m": "facebook/opt-125m",
     "marian": "sshleifer/tiny-marian-en-de",
     "mbart": "hf-internal-testing/tiny-random-mbart",
+    "minicpm": "katuni4ka/tiny-random-minicpm",
     "mistral": "echarlaix/tiny-random-mistral",
     "mixtral": "TitanML/tiny-mixtral",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
```
