Commit e1b6a59

Apply sdpa for mpt and internlm (#676)
* apply sdpa for mpt and internlm
* fix baichuan-13b
* fix accuracy
* small refactoring
* add test for baichuan 13b
* add support for output_attentions
* code style
1 parent b017856 commit e1b6a59
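
Here "sdpa" refers to PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`. A minimal sketch of the kind of substitution such patchers make, with illustrative function and tensor names rather than the actual MPT/InternLM attention code:

```python
import torch
import torch.nn.functional as F

def eager_attention(query, key, value, attention_mask=None):
    # Reference "eager" path: explicit matmul, scaling, masking, softmax.
    scores = query @ key.transpose(-2, -1) / (query.size(-1) ** 0.5)
    if attention_mask is not None:
        scores = scores + attention_mask
    return torch.softmax(scores, dim=-1) @ value

def sdpa_attention(query, key, value, attention_mask=None):
    # Fused equivalent: a single SDPA call that the export can map to an
    # optimized attention kernel.
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
```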

File tree

5 files changed: +359 -3 lines changed

optimum/exporters/openvino/convert.py (+2 -2)
```diff
@@ -358,6 +358,7 @@ def ts_patched_forward(*args, **kwargs):
 
     with patcher:
         check_dummy_inputs_are_allowed(model, dummy_inputs)
+        sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
         inputs = config.ordered_inputs(model)
         input_names = list(inputs.keys())
         output_names = list(config.outputs.keys())
@@ -387,7 +388,6 @@ def ts_patched_forward(*args, **kwargs):
                 ov_config=ov_config,
             )
 
-        sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
         ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
         if not ordered_dummy_inputs:
             ordered_dummy_inputs = dummy_inputs
```
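
Capturing `sig` earlier, right after the dummy-input check, keeps the signature available for the dummy-input reordering further down. A small, self-contained illustration of what that reordering step does, using hypothetical parameter names:

```python
import inspect

def forward(input_ids, attention_mask=None, past_key_values=None):
    ...

# Dummy inputs may arrive in arbitrary order.
dummy_inputs = {"attention_mask": 1, "input_ids": 0}

sig = inspect.signature(forward)

# Re-key the dummy inputs in the order the signature declares them.
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
print(list(ordered_dummy_inputs))  # ['input_ids', 'attention_mask']
```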
```diff
@@ -403,7 +403,7 @@ def ts_patched_forward(*args, **kwargs):
             inp_tensor.get_tensor().set_names({input_name})
             inp_data = flatten_inputs[idx]
             static_shape = PartialShape(inp_data.shape)
-            dims = inputs[input_name]
+            dims = inputs.get(input_name, [])
             for dim in dims:
                 static_shape[dim] = -1
             inp_tensor.get_node().set_partial_shape(static_shape)
```
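
The other change swaps the direct lookup `inputs[input_name]` for `inputs.get(input_name, [])`, so an input with no registered dynamic axes keeps its fully static shape instead of raising `KeyError`. A minimal sketch, with hypothetical input names and axis mappings:

```python
# Dynamic axes known to the export config (hypothetical values).
inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}

# Before: a direct lookup failed for inputs absent from the mapping.
# dims = inputs["beam_idx"]        # KeyError: 'beam_idx'

# After: absent inputs fall back to an empty collection, the loop body
# is skipped, and the corresponding shape stays static.
dims = inputs.get("beam_idx", [])
for dim in dims:
    print(f"dimension {dim} would be marked dynamic (-1)")
```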

optimum/exporters/openvino/model_configs.py (+18)
```diff
@@ -23,6 +23,7 @@
     FalconOnnxConfig,
     GemmaOnnxConfig,
     LlamaOnnxConfig,
+    MPTOnnxConfig,
     PhiOnnxConfig,
     UNetOnnxConfig,
     VaeDecoderOnnxConfig,
@@ -43,8 +44,10 @@
     BaichuanModelPatcher,
     ChatGLMModelPatcher,
     GemmaModelPatcher,
+    InternLMPatcher,
     LlamaModelPatcher,
     MixtralModelPatcher,
+    MPTModelPatcher,
     Phi3ModelPatcher,
     QwenModelPatcher,
 )
```
```diff
@@ -439,6 +442,11 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return InternLMPatcher(self, model, model_kwargs=model_kwargs)
+
 
 @register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -455,6 +463,16 @@ class OlmoOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
+@register_in_tasks_manager(
+    "mpt", *["text-generation", "text-generation-with-past", "text-classification"], library_name="transformers"
+)
+class MPTOpenVINOConfig(MPTOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return MPTModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager(
     "phi3",
     *[
```
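
With these registrations in place, exporting an MPT or InternLM2 checkpoint picks up the new configs (and their patchers) from the model type. A hedged usage sketch; the checkpoint name below is only an example:

```python
from optimum.intel import OVModelForCausalLM

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly;
# the "mpt" entry registered above supplies the config and MPTModelPatcher.
model = OVModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",  # example checkpoint, swap in your own
    export=True,
    trust_remote_code=True,
)
```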
