Skip to content

Commit a85eae6

Browse files
committed
update codegen config for support codegen2
1 parent 7114900 commit a85eae6

File tree

4 files changed

+45
-0
lines changed

4 files changed

+45
-0
lines changed

optimum/exporters/openvino/model_configs.py

+14
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
2222
from optimum.exporters.onnx.model_configs import (
23+
CodeGenOnnxConfig,
2324
FalconOnnxConfig,
2425
GemmaOnnxConfig,
2526
LlamaOnnxConfig,
@@ -44,6 +45,7 @@
4445
AquilaModelPatcher,
4546
BaichuanModelPatcher,
4647
ChatGLMModelPatcher,
48+
CodeGenModelPatcher,
4749
GemmaModelPatcher,
4850
InternLM2Patcher,
4951
InternLMModelPatcher,
@@ -738,3 +740,15 @@ def patch_model_for_export(
738740
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
739741
) -> "ModelPatcher":
740742
return InternLMModelPatcher(self, model, model_kwargs=model_kwargs)
743+
744+
745+
@register_in_tasks_manager(
746+
"codegen",
747+
*["feature-extraction", "feature-extraction-with-past", "text-generation", "text-generation-with-past"],
748+
library_name="transformers",
749+
)
750+
class CodeGenOpenVINOConfig(CodeGenOnnxConfig):
751+
def patch_model_for_export(
752+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
753+
) -> "ModelPatcher":
754+
return CodeGenModelPatcher(self, model, model_kwargs=model_kwargs)

optimum/exporters/openvino/model_patcher.py

+28
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
from transformers.modeling_tf_utils import TFPreTrainedModel
4444

4545

46+
BETTERTRANSFORMER_IGNORE = ("codegen",)
47+
48+
4649
def patch_model_with_bettertransformer(model):
4750
COLOR_RED = "\033[1;31m"
4851
COLOR_RESET = "\033[0m"
@@ -81,6 +84,10 @@ def patch_model_with_bettertransformer(model):
8184
# model already has required SDPA implementation
8285
if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa":
8386
return model
87+
88+
if model.config.model_type in BETTERTRANSFORMER_IGNORE:
89+
return model
90+
8491
try:
8592
model = model.to_bettertransformer()
8693
except Exception as e:
@@ -1328,3 +1335,24 @@ def __exit__(self, exc_type, exc_value, traceback):
13281335
for layer in self._model.model.layers:
13291336
if hasattr(layer.self_attn, "_orig_forward"):
13301337
layer.self_attn.forward = layer.self_attn._orig_forward
1338+
1339+
1340+
class CodeGenModelPatcher(DecoderModelPatcher):
1341+
def __enter__(self):
1342+
super().__enter__()
1343+
1344+
# whole codegen bettertransformer patch include attn.forward and does not cover codegen2.
1345+
# For avoiding breaking model on tracing stage, we reduce area of bettertransformer patch only for _attn.
1346+
from optimum.bettertransformer.models.attention import codegen_wrapped_scaled_dot_product
1347+
1348+
for layer in self._model.transformer.h:
1349+
if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions:
1350+
orig_self_attn_fwd = layer.attn._attn
1351+
layer.attn._attn = types.MethodType(codegen_wrapped_scaled_dot_product, layer.attn)
1352+
layer.attn._orig_attn = orig_self_attn_fwd
1353+
1354+
def __exit__(self, exc_type, exc_value, traceback):
1355+
super().__exit__(exc_type, exc_value, traceback)
1356+
for layer in self._model.transformer.h:
1357+
if hasattr(layer.attn, "_orig_attn"):
1358+
layer.attn._attn = layer.attn._orig_attn

tests/openvino/test_modeling.py

+2
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
527527
"bloom",
528528
"chatglm",
529529
"codegen",
530+
"codegen2",
530531
# "data2vec-text", # TODO : enable when enabled in exporters
531532
"gemma",
532533
"gpt2",
@@ -577,6 +578,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
577578
"aquila2",
578579
"xverse",
579580
"internlm",
581+
"codegen2",
580582
)
581583

582584
@parameterized.expand(SUPPORTED_ARCHITECTURES)

tests/openvino/utils_tests.py

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"cohere": "hf-internal-testing/tiny-random-CohereForCausalLM",
3838
"chatglm": "katuni4ka/tiny-random-chatglm2",
3939
"codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM",
40+
"codegen2": "katuni4ka/tiny-random-codegen2",
4041
"data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel",
4142
"data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel",
4243
"data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",

0 commit comments

Comments
 (0)