|
20 | 20 | from typing import Dict, Optional, Union
|
21 | 21 |
|
22 | 22 | import openvino
|
23 |
| -from huggingface_hub import hf_hub_download |
| 23 | +from huggingface_hub import hf_hub_download, HfApi |
24 | 24 | from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
25 | 25 | from openvino import Core, convert_model
|
26 | 26 | from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
|
27 | 27 | from transformers import GenerationConfig, PretrainedConfig
|
28 | 28 | from transformers.file_utils import add_start_docstrings
|
29 | 29 | from transformers.generation import GenerationMixin
|
| 30 | +from transformers import AutoConfig |
30 | 31 |
|
31 | 32 | from optimum.exporters.onnx import OnnxConfig
|
| 33 | +from optimum.exporters.tasks import TasksManager |
32 | 34 | from optimum.modeling_base import OptimizedModel
|
| 35 | +from optimum.utils import CONFIG_NAME |
| 36 | +from optimum.modeling_base import FROM_PRETRAINED_START_DOCSTRING |
33 | 37 |
|
34 | 38 | from ...exporters.openvino import export, main_export
|
35 | 39 | from ..utils.import_utils import is_nncf_available
|
@@ -524,3 +528,127 @@ def can_generate(self) -> bool:
|
524 | 528 | if isinstance(self, GenerationMixin):
|
525 | 529 | return True
|
526 | 530 | return False
|
| 531 | + |
| 532 | + @classmethod |
| 533 | + @add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING) |
| 534 | + def from_pretrained( |
| 535 | + cls, |
| 536 | + model_id: Union[str, Path], |
| 537 | + export: Optional[bool] = None, |
| 538 | + force_download: bool = False, |
| 539 | + use_auth_token: Optional[str] = None, |
| 540 | + cache_dir: str = HUGGINGFACE_HUB_CACHE, |
| 541 | + subfolder: str = "", |
| 542 | + config: Optional[PretrainedConfig] = None, |
| 543 | + local_files_only: bool = False, |
| 544 | + trust_remote_code: bool = False, |
| 545 | + revision: Optional[str] = None, |
| 546 | + **kwargs, |
| 547 | + ) -> "OptimizedModel": |
| 548 | + """ |
| 549 | + Returns: |
| 550 | + `OptimizedModel`: The loaded optimized model. |
| 551 | + """ |
| 552 | + if isinstance(model_id, Path): |
| 553 | + model_id = model_id.as_posix() |
| 554 | + |
| 555 | + from_transformers = kwargs.pop("from_transformers", None) |
| 556 | + if from_transformers is not None: |
| 557 | + logger.warning( |
| 558 | + "The argument `from_transformers` is deprecated, and will be removed in optimum 2.0. Use `export` instead" |
| 559 | + ) |
| 560 | + export = from_transformers |
| 561 | + |
| 562 | + if len(model_id.split("@")) == 2: |
| 563 | + if revision is not None: |
| 564 | + logger.warning( |
| 565 | + f"The argument `revision` was set to {revision} but will be ignored for {model_id.split('@')[1]}" |
| 566 | + ) |
| 567 | + model_id, revision = model_id.split("@") |
| 568 | + |
| 569 | + library_name = TasksManager.infer_library_from_model( |
| 570 | + model_id, subfolder, revision, cache_dir, use_auth_token=use_auth_token |
| 571 | + ) |
| 572 | + |
| 573 | + if library_name == "timm": |
| 574 | + config = PretrainedConfig.from_pretrained(model_id, subfolder, revision) |
| 575 | + |
| 576 | + if config is None: |
| 577 | + if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME: |
| 578 | + if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)): |
| 579 | + config = AutoConfig.from_pretrained( |
| 580 | + os.path.join(model_id, subfolder, CONFIG_NAME), trust_remote_code=trust_remote_code |
| 581 | + ) |
| 582 | + elif CONFIG_NAME in os.listdir(model_id): |
| 583 | + config = AutoConfig.from_pretrained( |
| 584 | + os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code |
| 585 | + ) |
| 586 | + logger.info( |
| 587 | + f"config.json not found in the specified subfolder {subfolder}. Using the top level config.json." |
| 588 | + ) |
| 589 | + else: |
| 590 | + raise OSError(f"config.json not found in {model_id} local folder") |
| 591 | + else: |
| 592 | + config = cls._load_config( |
| 593 | + model_id, |
| 594 | + revision=revision, |
| 595 | + cache_dir=cache_dir, |
| 596 | + use_auth_token=use_auth_token, |
| 597 | + force_download=force_download, |
| 598 | + subfolder=subfolder, |
| 599 | + trust_remote_code=trust_remote_code, |
| 600 | + ) |
| 601 | + elif isinstance(config, (str, os.PathLike)): |
| 602 | + config = cls._load_config( |
| 603 | + config, |
| 604 | + revision=revision, |
| 605 | + cache_dir=cache_dir, |
| 606 | + use_auth_token=use_auth_token, |
| 607 | + force_download=force_download, |
| 608 | + subfolder=subfolder, |
| 609 | + trust_remote_code=trust_remote_code, |
| 610 | + ) |
| 611 | + |
| 612 | + if export is None: |
| 613 | + export = cls._check_export_status(model_id, revision, subfolder) |
| 614 | + |
| 615 | + if not export and trust_remote_code: |
| 616 | + logger.warning( |
| 617 | + "The argument `trust_remote_code` is to be used along with export=True. It will be ignored." |
| 618 | + ) |
| 619 | + elif export and trust_remote_code is None: |
| 620 | + trust_remote_code = False |
| 621 | + |
| 622 | + |
| 623 | + from_pretrained_method = cls._from_transformers if export else cls._from_pretrained |
| 624 | + |
| 625 | + return from_pretrained_method( |
| 626 | + model_id=model_id, |
| 627 | + config=config, |
| 628 | + revision=revision, |
| 629 | + cache_dir=cache_dir, |
| 630 | + force_download=force_download, |
| 631 | + use_auth_token=use_auth_token, |
| 632 | + subfolder=subfolder, |
| 633 | + local_files_only=local_files_only, |
| 634 | + trust_remote_code=trust_remote_code, |
| 635 | + **kwargs, |
| 636 | + ) |
| 637 | + |
| 638 | + @classmethod |
| 639 | + def _check_export_status(cls, model_id: Union[str, Path], revision: Optional[str] = None, subfolder: str = ""): |
| 640 | + model_dir = Path(model_id) |
| 641 | + if subfolder is not None: |
| 642 | + model_dir = model_dir / subfolder |
| 643 | + if model_dir.is_dir(): |
| 644 | + return not (model_dir / OV_XML_FILE_NAME).exists() or not (model_dir / OV_XML_FILE_NAME.replace(".xml", ".bin")).exists() |
| 645 | + |
| 646 | + hf_api = HfApi() |
| 647 | + try: |
| 648 | + model_info = hf_api.model_info(model_id, revision=revision or "main") |
| 649 | + normalized_subfolder = None if subfolder is None else Path(subfolder).as_posix() |
| 650 | + model_files = [file.rfilename for file in model_info.siblings if normalized_subfolder is None or file.rfilename.startswith(normalized_subfolder)] |
| 651 | + ov_model_path = OV_XML_FILE_NAME if subfolder is None else f"{normalized_subfolder}/{OV_XML_FILE_NAME}" |
| 652 | + return not ov_model_path in model_files or not ov_model_path.replace(".xml", ".bin") in model_files |
| 653 | + except Exception: |
| 654 | + return True |
0 commit comments