|
18 | 18 | from pathlib import Path
|
19 | 19 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
20 | 20 |
|
| 21 | +from transformers import PretrainedConfig |
21 | 22 | from transformers.utils import is_torch_available
|
22 | 23 |
|
23 | 24 | from openvino.runtime import Dimension, PartialShape, Symbol
|
24 | 25 | from openvino.runtime.utils.types import get_element_type
|
25 | 26 | from optimum.exporters import TasksManager
|
26 | 27 | from optimum.exporters.onnx.base import OnnxConfig
|
| 28 | +from optimum.intel.utils import is_transformers_version |
27 | 29 | from optimum.utils import is_diffusers_available
|
| 30 | +from optimum.utils.save_utils import maybe_save_preprocessors |
28 | 31 |
|
29 | 32 |
|
30 | 33 | logger = logging.getLogger(__name__)
|
@@ -227,3 +230,25 @@ def save_config(config, save_dir):
|
227 | 230 | save_dir.mkdir(exist_ok=True, parents=True)
|
228 | 231 | output_config_file = Path(save_dir / "config.json")
|
229 | 232 | config.to_json_file(output_config_file, use_diff=True)
|
| 233 | + |
| 234 | + |
| 235 | +def save_preprocessors( |
| 236 | + preprocessors: List, config: PretrainedConfig, output: Union[str, Path], trust_remote_code: bool |
| 237 | +): |
| 238 | + model_name_or_path = config._name_or_path |
| 239 | + if hasattr(config, "export_model_type"): |
| 240 | + model_type = config.export_model_type.replace("_", "-") |
| 241 | + else: |
| 242 | + model_type = config.model_type.replace("_", "-") |
| 243 | + if preprocessors is not None: |
| 244 | + # phi3-vision processor does not have chat_template attribute that breaks Processor saving on disk |
| 245 | + if is_transformers_version(">=", "4.45") and model_type == "phi3-v" and len(preprocessors) > 1: |
| 246 | + if not hasattr(preprocessors[1], "chat_template"): |
| 247 | + preprocessors[1].chat_template = getattr(preprocessors[0], "chat_template", None) |
| 248 | + for processor in preprocessors: |
| 249 | + try: |
| 250 | + processor.save_pretrained(output) |
| 251 | + except Exception as ex: |
| 252 | + logger.error(f"Saving {type(processor)} failed with {ex}") |
| 253 | + else: |
| 254 | + maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code) |
0 commit comments