|
19 | 19 | import os
|
20 | 20 | from pathlib import Path
|
21 | 21 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
| 22 | +import tempfile |
22 | 23 |
|
23 | 24 | import onnx
|
24 | 25 | from transformers.utils import is_tf_available, is_torch_available
|
25 | 26 |
|
26 |
| -from openvino.runtime import Model, PartialShape, save_model |
| 27 | +from openvino.runtime import Model, PartialShape, save_model, Core |
27 | 28 | from openvino.runtime.exceptions import OVTypeError
|
28 | 29 | from openvino.runtime.utils.types import get_element_type
|
29 | 30 | from openvino.tools.ovc import convert_model
|
|
58 | 59 |
|
59 | 60 |
|
60 | 61 | logger = logging.getLogger(__name__)
|
| 62 | +core = Core() |
61 | 63 |
|
62 | 64 | if is_torch_available():
|
63 | 65 | import torch.nn as nn
|
@@ -412,11 +414,24 @@ def ts_patched_forward(*args, **kwargs):
|
412 | 414 |
|
413 | 415 | if stateful:
|
414 | 416 | patch_stateful(model.config, ov_model)
|
| 417 | + if ov_config.quantization_config: |
| 418 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 419 | + tmp_output = Path(temp_dir) / output.name |
| 420 | + _save_model(ov_model, tmp_output, ov_config=None) |
| 421 | + clear_class_registry() |
| 422 | + del ov_model |
| 423 | + del model |
| 424 | + gc.collect() |
| 425 | + |
| 426 | + ov_model = core.read_model(tmp_output) |
| 427 | + _save_model(ov_model, output, ov_config) |
| 428 | + else: |
| 429 | + _save_model(ov_model, output, ov_config) |
| 430 | + clear_class_registry() |
| 431 | + del ov_model |
| 432 | + del model |
| 433 | + gc.collect() |
415 | 434 |
|
416 |
| - _save_model(ov_model, output, ov_config=ov_config) |
417 |
| - clear_class_registry() |
418 |
| - del model |
419 |
| - gc.collect() |
420 | 435 | return input_names, output_names, False
|
421 | 436 |
|
422 | 437 |
|
|
0 commit comments