Commit 2696e6f

Add export config for gemma2 (huggingface#876)
* add export config for gemma2
* update cache position and tests
* update model list
* fix without-cache export
* patch original torch gemma2 to work with dynamic cache
* Update tests/openvino/test_modeling.py
* prevent usage of cache implementation
* add min transformers version
1 parent d4e3128 commit 2696e6f
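
After this commit, a Gemma2 checkpoint can be exported and run through the usual optimum-intel entry point. A minimal usage sketch; the checkpoint id and generation settings below are illustrative assumptions, not part of the commit:

# Minimal usage sketch; the checkpoint id is an assumption, any Gemma2 model should work.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "google/gemma-2-2b-it"  # assumed checkpoint, requires transformers >= 4.43
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # converts to OpenVINO IR on the fly

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))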

File tree

6 files changed: +112 -6 lines changed

docs/source/openvino/models.mdx

+1

@@ -55,6 +55,7 @@ Here is the list of the supported architectures :
 - GPT-NeoX
 - GPT-NeoX-Japanese
 - Gemma
+- Gemma2
 - Hubert
 - IBert
 - InternLM

optimum/exporters/openvino/model_configs.py

+21

@@ -54,6 +54,7 @@
     CodeGenModelPatcher,
     DBRXModelPatcher,
     FalconModelPatcher,
+    Gemma2ModelPatcher,
     GptNeoxJapaneseModelPatcher,
     GptNeoxModelPatcher,
     InternLM2Patcher,
@@ -997,3 +998,23 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return GptNeoxModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "gemma2",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class Gemma2OpenVINOConfig(GemmaOnnxConfig):
+    MIN_TRANSFORMERS_VERSION = version.parse("4.43.0")
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return Gemma2ModelPatcher(self, model, model_kwargs=model_kwargs)
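
Registering the config via register_in_tasks_manager makes it discoverable by the OpenVINO export machinery. A short sketch of driving the export programmatically, assuming a checkpoint id and output directory that are not part of this commit:

# Sketch of exporting a Gemma2 checkpoint with the OpenVINO exporter; the id and path are assumptions.
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="google/gemma-2-2b-it",  # assumed checkpoint
    output="gemma2_openvino",                   # assumed output directory
    task="text-generation-with-past",           # one of the tasks registered above
)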

optimum/exporters/openvino/model_patcher.py

+59 -1

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
 import inspect
 import logging as log
 import math
@@ -23,7 +24,7 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.utils import is_tf_available

-from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
+from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, override_arguments
 from optimum.intel.utils.import_utils import (
     _openvino_version,
     _torch_version,
@@ -2409,3 +2410,60 @@ def __enter__(self):
         super().__enter__()
         for layer in self._model.gpt_neox_japanese.layers:
             _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
+
+
+class Gemma2ModelPatcher(LlamaModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        @functools.wraps(self.orig_forward)
+        def patched_forward(*args, **kwargs):
+            from transformers.cache_utils import DynamicCache
+
+            signature = inspect.signature(self.orig_forward)
+            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
+
+            return_legacy_cache = False
+            pkv_in_args = False
+            legacy_pkv = None
+            if "past_key_values" in kwargs:
+                legacy_pkv = kwargs.pop("past_key_values", None)
+            sign_names = list(signature.parameters.keys())
+            pkv_argument_index = sign_names.index("past_key_values")
+            cache_position_index = sign_names.index("cache_position") if "cache_position" in sign_names else -1
+            input_ids_index = sign_names.index("input_ids" if "input_ids" in sign_names else "inputs_embeds")
+            if legacy_pkv is None and len(args) > pkv_argument_index:
+                legacy_pkv = args[pkv_argument_index]
+                pkv_in_args = True
+            if legacy_pkv is not None:
+                pkv = DynamicCache.from_legacy_cache(legacy_pkv)
+                return_legacy_cache = True
+                if not pkv_in_args:
+                    kwargs["past_key_values"] = pkv
+                else:
+                    args[pkv_argument_index] = pkv
+
+            if (
+                return_legacy_cache
+                and cache_position_index != -1
+                and (cache_position_index > len(args) and "cache_position" not in kwargs)
+            ):
+                past_seen_tokens = legacy_pkv[0][0].shape[-2]
+                input_ids = args[input_ids_index]
+                cache_position = torch.arange(
+                    past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device
+                )
+                kwargs["cache_position"] = cache_position
+
+            outputs = self.orig_forward(*args, **kwargs)
+            if return_legacy_cache:
+                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+
+            return outputs
+
+        self.patched_forward = patched_forward
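
For context, the cache handling this patcher adds can be sketched in isolation: a legacy tuple-of-tuples past_key_values is wrapped into a transformers DynamicCache before calling the original forward and converted back afterwards. A minimal sketch; the tensor shapes are illustrative assumptions, not taken from the commit:

# Minimal sketch of the legacy <-> DynamicCache round trip used by Gemma2ModelPatcher.
# Assumes transformers >= 4.43; shapes (1 layer, batch 1, 2 heads, 3 past tokens, head dim 4) are arbitrary.
import torch
from transformers.cache_utils import DynamicCache

# legacy format: tuple of per-layer (key, value) pairs, each [batch, num_heads, seq_len, head_dim]
legacy_pkv = ((torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)),)

cache = DynamicCache.from_legacy_cache(legacy_pkv)   # what the patcher feeds into forward
past_seen_tokens = legacy_pkv[0][0].shape[-2]         # 3, used to build cache_position
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + 1)  # position of the next token

roundtrip = cache.to_legacy_cache()                   # what the patcher hands back to the exporter
assert roundtrip[0][0].shape == legacy_pkv[0][0].shape
print(past_seen_tokens, cache_position)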

optimum/intel/openvino/modeling_decoder.py

+2

@@ -806,6 +806,8 @@ def _from_pretrained(
                 force_download=force_download,
                 local_files_only=local_files_only,
             )
+            if getattr(generation_config, "cache_implementation", None) is not None:
+                generation_config.cache_implementation = None
             kwargs["generation_config"] = generation_config
         except Exception:
             pass
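
This guards against Gemma2 checkpoints that ship cache_implementation="hybrid" in their generation config, which the exported stateful OpenVINO model does not use. A minimal sketch of the condition being neutralized; the config value is constructed here purely for illustration:

# Sketch of the guard above; cache_implementation="hybrid" mimics what Gemma2 checkpoints ship.
from transformers import GenerationConfig

generation_config = GenerationConfig(cache_implementation="hybrid")
if getattr(generation_config, "cache_implementation", None) is not None:
    generation_config.cache_implementation = None  # fall back to the model's own stateful cache
print(generation_config.cache_implementation)  # None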

tests/openvino/test_modeling.py

+28 -5

@@ -57,6 +57,7 @@
 from transformers.testing_utils import slow
 from utils_tests import MODEL_NAMES

+from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
 from optimum.intel import (
     OVModelForAudioClassification,
     OVModelForAudioFrameClassification,
@@ -647,6 +648,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
     if is_transformers_version(">=", "4.40.0"):
         SUPPORTED_ARCHITECTURES += (
             "gemma",
+            "gemma2",
             "olmo",
             "stablelm",
             "starcoder2",
@@ -728,7 +730,8 @@ def test_compare_to_transformers(self, model_arch):
         self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=1e-4))

         # Qwen tokenizer does not support padding
-        if model_arch == "qwen":
+
+        if model_arch in ["qwen"]:
             return

         if model_arch not in ["chatglm", "glm4", "persimmon"]:
@@ -753,7 +756,16 @@ def test_compare_to_transformers(self, model_arch):
         )

         ov_outputs = ov_model.generate(**tokens, generation_config=gen_config)
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        additional_inputs = {}
+        # gemma2 does not support dynamic cache, it is unfair to compare dynamic cache result vs hybrid cache,
+        # align cache representation in torch model
+        if model_arch == "gemma2":
+            patch_update_causal_mask(transformers_model, "4.43.0")
+            transformers_model._supports_cache_class = True
+            from transformers.cache_utils import DynamicCache
+
+            additional_inputs = {"past_key_values": DynamicCache()}
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
         self.assertTrue(torch.allclose(ov_outputs, transformers_outputs))

         del transformers_model
@@ -921,8 +933,8 @@ def test_beam_search(self, model_arch):
             "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
             "trust_remote_code": True,
         }
-        # Qwen tokenizer does not support padding, chatgm testing model produces nan that incompatible with beam search
-        if model_arch in ["qwen", "chatglm"]:
+        # Qwen tokenizer does not support padding, chatglm, glm4 testing models produce nan that incompatible with beam search
+        if model_arch in ["qwen", "chatglm", "glm4"]:
             return

         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
@@ -988,6 +1000,12 @@ def test_beam_search(self, model_arch):

         if model_arch == "arctic":
             transformers_model.to(torch.float32)
+        additional_inputs = {}
+        # gemma2 does not support dynamic cache, it is unfair to compare dynamic cache result vs hybrid cache, align cache representation in torch model
+        if model_arch == "gemma2":
+            patch_update_causal_mask(transformers_model, "4.43.0")
+            transformers_model._supports_cache_class = True
+            from transformers.cache_utils import DynamicCache
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
         tokens.pop("token_type_ids", None)
@@ -1002,7 +1020,12 @@ def test_beam_search(self, model_arch):
             if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]:
                 continue
             set_seed(SEED)
-            transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+
+            if model_arch == "gemma2":
+                additional_inputs = {"past_key_values": DynamicCache()}
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
             set_seed(SEED)
             ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
             self.assertTrue(
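
A condensed sketch of the gemma2-specific branch exercised by these tests, using the tiny checkpoint registered in utils_tests.py; the generation settings are illustrative assumptions, not the test's exact configuration:

# Sketch of the comparison pattern above: align the PyTorch reference's cache representation
# with the exported OpenVINO model before comparing generate() outputs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
from optimum.intel import OVModelForCausalLM

model_id = "katuni4ka/tiny-random-gemma2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")

transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)

# As in the tests: patch the causal mask update and pass an explicit DynamicCache
# so the reference model does not build its default hybrid cache.
patch_update_causal_mask(transformers_model, "4.43.0")
transformers_model._supports_cache_class = True

transformers_outputs = transformers_model.generate(
    **tokens, max_new_tokens=10, do_sample=False, past_key_values=DynamicCache()
)
ov_outputs = ov_model.generate(**tokens, max_new_tokens=10, do_sample=False)
print(torch.allclose(ov_outputs, transformers_outputs))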

tests/openvino/utils_tests.py

+1

@@ -54,6 +54,7 @@
     "electra": "hf-internal-testing/tiny-random-electra",
     "exaone": "katuni4ka/tiny-random-exaone",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "gemma2": "katuni4ka/tiny-random-gemma2",
     "falcon": "fxmarty/really-tiny-falcon-testing",
     "falcon-40b": "katuni4ka/tiny-random-falcon-40b",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
