From 309d16058992f2a32ba549a29475c234010ba129 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 22 May 2024 14:28:41 +0400 Subject: [PATCH] dump model before compress cli --- optimum/exporters/openvino/convert.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 3b214f77e4..32be0c774a 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -19,11 +19,12 @@ import os from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +import tempfile import onnx from transformers.utils import is_tf_available, is_torch_available -from openvino.runtime import Model, PartialShape, save_model +from openvino.runtime import Model, PartialShape, save_model, Core from openvino.runtime.exceptions import OVTypeError from openvino.runtime.utils.types import get_element_type from openvino.tools.ovc import convert_model @@ -58,6 +59,7 @@ logger = logging.getLogger(__name__) +core = Core() if is_torch_available(): import torch.nn as nn @@ -412,11 +414,24 @@ def ts_patched_forward(*args, **kwargs): if stateful: patch_stateful(model.config, ov_model) + if ov_config.quantization_config: + with tempfile.TemporaryDirectory() as temp_dir: + tmp_output = Path(temp_dir) / output.name + _save_model(ov_model, tmp_output, ov_config=None) + clear_class_registry() + del ov_model + del model + gc.collect() + + ov_model = core.read_model(tmp_output) + _save_model(ov_model, output, ov_config) + else: + _save_model(ov_model, output, ov_config) + clear_class_registry() + del ov_model + del model + gc.collect() - _save_model(ov_model, output, ov_config=ov_config) - clear_class_registry() - del model - gc.collect() return input_names, output_names, False