Skip to content

Commit 588738d

Browse files
authored
minicpm fix for bf16 (#1143)
1 parent cd44f82 commit 588738d

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

optimum/exporters/openvino/model_configs.py

+6
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,7 @@
9595
LlavaImageEmbeddingModelPatcher,
9696
LlavaQwen2ImageEmbeddingsModelPatcher,
9797
MiniCPM3Patcher,
98+
MiniCPMModelPatcher,
9899
MiniCPMVImageEmbeddingsModelPatcher,
99100
MiniCPMVResamplerModelPatcher,
100101
MistralModelPatcher,
@@ -221,6 +222,11 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
221222
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
222223
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
223224

225+
def patch_model_for_export(
    self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
    """Build the model patcher used when exporting *model* with this config.

    Args:
        model: The model instance that is about to be exported.
        model_kwargs: Optional extra keyword arguments, forwarded unchanged
            to the patcher.

    Returns:
        A ``MiniCPMModelPatcher`` constructed from this config, *model* and
        *model_kwargs*.
    """
    patcher = MiniCPMModelPatcher(self, model, model_kwargs=model_kwargs)
    return patcher
229+
224230

225231
class OVMiniCPM3DummyPastKeyValuesGenerator(MistralDummyPastKeyValuesGenerator):
226232
def __init__(

optimum/exporters/openvino/model_patcher.py

+15
Original file line number | Diff line number | Diff line change
@@ -3905,3 +3905,18 @@ def __exit__(self, exc_type, exc_value, traceback):
39053905
for layer in self._model.layers:
39063906
if hasattr(layer.self_attn, "_orig_forward"):
39073907
layer.self_attn.forward = layer.self_attn._orig_forward
3908+
3909+
3910+
class MiniCPMModelPatcher(DecoderModelPatcher):
    """Decoder patcher for MiniCPM that upcasts selected projections to float32.

    On construction, every entry of ``model.model.layers`` that exposes a
    ``scale_depth`` attribute has its attention output projection and its MLP
    down projection moved to ``torch.float32``; the base patcher is then
    initialised as usual.

    NOTE(review): presumably this works around precision problems with
    bf16 checkpoints -- confirm against the originating change.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Upcast the affected projections, then defer to the base patcher.

        Args:
            config: Export configuration handed to the base patcher.
            model: Model to patch; must provide ``model.model.layers``.
            model_kwargs: Optional extra keyword arguments handed to the
                base patcher.
        """
        upcast_dtype = torch.float32
        for decoder_block in model.model.layers:
            # Blocks without ``scale_depth`` are left exactly as they are.
            if not hasattr(decoder_block, "scale_depth"):
                continue
            # The result of ``.to`` is discarded, so this relies on the
            # projections being converted in place.
            decoder_block.self_attn.o_proj.to(upcast_dtype)
            decoder_block.mlp.down_proj.to(upcast_dtype)

        super().__init__(config, model, model_kwargs)

0 commit comments

Comments
 (0)