
Commit fa49187

Add openvino export for InternLM2 and Orion architectures (#628)
* support more models in export
* add orion
* update tests
1 parent: c935a3d

File tree

6 files changed: +41, -5 lines
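
With these configs registered, InternLM2 and Orion checkpoints can be converted through the existing OpenVINO model classes. Below is a minimal sketch using the tiny InternLM2 test checkpoint referenced in the tests of this PR; any InternLM2 or Orion model id works the same way, and `trust_remote_code=True` is needed because these architectures ship their modeling code on the Hub.

from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "katuni4ka/tiny-random-internlm2"  # tiny test checkpoint from this PR; swap in a real InternLM2/Orion model
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly via the newly registered export config
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))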

optimum/exporters/openvino/__main__.py  (+14, -1)

@@ -202,7 +202,6 @@ def main_export(
     quantization_config = getattr(config, "quantization_config", None)
     do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
     model_type = config.model_type.replace("_", "-")
-
     if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
         custom_architecture = True
     elif task not in TasksManager.get_supported_tasks_for_model_type(
@@ -220,6 +219,20 @@ def main_export(
         )
     if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
         loading_kwargs["attn_implementation"] = "eager"
+    # There are some differences between the remote-code and in-library representations of past key values
+    # for some models; to avoid confusion, we disable remote code for them.
+    if (
+        trust_remote_code
+        and model_type in {"falcon", "mpt", "phi"}
+        and ("with-past" in task or original_task == "auto")
+        and not custom_export_configs
+    ):
+        logger.warning(
+            f"Model type `{model_type}` export for task `{task}` is not supported for loading with `trust_remote_code=True` "
+            "using the default export configuration; `trust_remote_code` will be disabled. "
+            "Please provide a custom export config if you want to load the model with remote code."
+        )
+        trust_remote_code = False

     # Patch the modules to export of GPTQ models w/o GPU
     if do_gptq_patching:
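
The guard added above only fires when the default export configuration would be used for a with-past task; passing custom export configs keeps remote code enabled. A standalone sketch of the same decision, with the hypothetical helper name should_disable_remote_code (not part of the library):

def should_disable_remote_code(model_type, task, original_task, trust_remote_code, custom_export_configs):
    # Remote-code falcon/mpt/phi variants lay out past key values differently from the
    # in-library classes, so the default with-past export config only matches the latter.
    return (
        trust_remote_code
        and model_type in {"falcon", "mpt", "phi"}
        and ("with-past" in task or original_task == "auto")
        and not custom_export_configs
    )


# Example: a remote-code falcon exported for text-generation-with-past falls back to the in-library code path.
assert should_disable_remote_code("falcon", "text-generation-with-past", "auto", True, None)
assert not should_disable_remote_code("falcon", "text-generation-with-past", "auto", True, {"text-generation": object()})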

optimum/exporters/openvino/convert.py  (+1, -1)

@@ -345,7 +345,7 @@ def ts_patched_forward(*args, **kwargs):
                 input_dict = dict(zip(keys, tuple_input))
                 kwargs[input_name] = input_dict
             outputs = patched_forward(*args, **kwargs)
-            return tuple(outputs.values())
+            return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])

         patcher.patched_forward = ts_patched_forward

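
The one-line change above normalizes the traced model's outputs: values returned as Python lists (for example past key values in some remote-code models) are converted to tuples so the flattened output structure stays consistent during tracing. A minimal sketch of the same normalization in isolation, using a hypothetical outputs dict rather than a real model:

import torch

# Hypothetical forward output: a logits tensor plus past key values returned as a list of tensors.
outputs = {
    "logits": torch.rand(1, 4, 32),
    "past_key_values": [torch.rand(1, 2, 4, 8), torch.rand(1, 2, 4, 8)],
}

flattened = tuple(value if not isinstance(value, list) else tuple(value) for value in outputs.values())
assert isinstance(flattened[1], tuple)  # the list-valued entry is now a tuple, like the other outputs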

optimum/exporters/openvino/model_configs.py  (+19, -1)

@@ -74,7 +74,7 @@ def init_model_configs():


 @register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers")
-class BaichaunOpenVINOConfig(TextDecoderOnnxConfig):
+class BaichaunOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
         num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
@@ -400,3 +400,21 @@ class Starcoder2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager("internlm2", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
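
Both new configs reuse the generic decoder-with-position-ids export config together with Mistral-style dummy past-key-value generators, so supporting a similar decoder-only architecture mostly amounts to one more registration like those above. A sketch for a hypothetical model type "my-decoder" (not a real architecture), assuming the same imports as this file:

@register_in_tasks_manager("my-decoder", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MyDecoderOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    # Same defaults as the InternLM2/Orion configs above.
    DEFAULT_ONNX_OPSET = 14

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig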

optimum/exporters/openvino/model_patcher.py  (+1, -1)

@@ -513,5 +513,5 @@ def __init__(
     ):
         super().__init__(config, model, model_kwargs)
         # model has first inference buffers initialization
-        if self._model.lm_head.first_flag:
+        if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
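
The patcher's warm-up forward pass only makes sense for model variants whose lm_head carries a first_flag buffer-initialization marker; switching from a direct attribute read to hasattr keeps the patcher from raising AttributeError on variants that do not define it. A tiny sketch of the guard pattern with a hypothetical head module (not library code):

import torch


class HeadWithWarmup(torch.nn.Linear):
    first_flag = True  # hypothetical marker: buffers are initialized on the first call


lm_head = HeadWithWarmup(8, 8)
# Run the warm-up only when the marker exists; heads without it are skipped instead of raising AttributeError.
if hasattr(lm_head, "first_flag"):
    lm_head(torch.ones((1, 8)))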

tests/openvino/test_modeling.py  (+3, -1)

@@ -524,10 +524,12 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "stablelm",
         "starcoder2",
         "phi",
+        "internlm2",
+        "orion",
     )
     GENERATION_LENGTH = 100
     IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")
-    REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen")
+    REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion")

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):

tests/openvino/utils_tests.py  (+3, -0)

@@ -50,6 +50,7 @@
     "gptj": "hf-internal-testing/tiny-random-GPTJModel",
     "hubert": "hf-internal-testing/tiny-random-HubertModel",
     "ibert": "hf-internal-testing/tiny-random-ibert",
+    "internlm2": "katuni4ka/tiny-random-internlm2",
     "levit": "hf-internal-testing/tiny-random-LevitModel",
     "longt5": "hf-internal-testing/tiny-random-longt5",
     "llama": "fxmarty/tiny-llama-fast-tokenizer",
@@ -69,6 +70,8 @@
     "mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
     "mt5": "stas/mt5-tiny-random",
     "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
+    "olmo": "katuni4ka/tiny-random-olmo",
+    "orion": "katuni4ka/tiny-random-orion",
     "pegasus": "hf-internal-testing/tiny-random-pegasus",
     "pix2struct": "fxmarty/pix2struct-tiny-random",
     "phi": "echarlaix/tiny-random-PhiForCausalLM",
