Skip to content

Commit b5d0069

Browse files
committed
optionally enable export if not exported model provided
1 parent 7a929e8 commit b5d0069

File tree

2 files changed

+154
-2
lines changed

2 files changed

+154
-2
lines changed

optimum/intel/openvino/modeling_base.py

+129-1
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,20 @@
2020
from typing import Dict, Optional, Union
2121

2222
import openvino
23-
from huggingface_hub import hf_hub_download
23+
from huggingface_hub import hf_hub_download, HfApi
2424
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
2525
from openvino import Core, convert_model
2626
from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
2727
from transformers import GenerationConfig, PretrainedConfig
2828
from transformers.file_utils import add_start_docstrings
2929
from transformers.generation import GenerationMixin
30+
from transformers import AutoConfig
3031

3132
from optimum.exporters.onnx import OnnxConfig
33+
from optimum.exporters.tasks import TasksManager
3234
from optimum.modeling_base import OptimizedModel
35+
from optimum.utils import CONFIG_NAME
36+
from optimum.modeling_base import FROM_PRETRAINED_START_DOCSTRING
3337

3438
from ...exporters.openvino import export, main_export
3539
from ..utils.import_utils import is_nncf_available
@@ -524,3 +528,127 @@ def can_generate(self) -> bool:
524528
if isinstance(self, GenerationMixin):
525529
return True
526530
return False
531+
532+
@classmethod
533+
@add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING)
534+
def from_pretrained(
535+
cls,
536+
model_id: Union[str, Path],
537+
export: Optional[bool] = None,
538+
force_download: bool = False,
539+
use_auth_token: Optional[str] = None,
540+
cache_dir: str = HUGGINGFACE_HUB_CACHE,
541+
subfolder: str = "",
542+
config: Optional[PretrainedConfig] = None,
543+
local_files_only: bool = False,
544+
trust_remote_code: bool = False,
545+
revision: Optional[str] = None,
546+
**kwargs,
547+
) -> "OptimizedModel":
548+
"""
549+
Returns:
550+
`OptimizedModel`: The loaded optimized model.
551+
"""
552+
if isinstance(model_id, Path):
553+
model_id = model_id.as_posix()
554+
555+
from_transformers = kwargs.pop("from_transformers", None)
556+
if from_transformers is not None:
557+
logger.warning(
558+
"The argument `from_transformers` is deprecated, and will be removed in optimum 2.0. Use `export` instead"
559+
)
560+
export = from_transformers
561+
562+
if len(model_id.split("@")) == 2:
563+
if revision is not None:
564+
logger.warning(
565+
f"The argument `revision` was set to {revision} but will be ignored for {model_id.split('@')[1]}"
566+
)
567+
model_id, revision = model_id.split("@")
568+
569+
library_name = TasksManager.infer_library_from_model(
570+
model_id, subfolder, revision, cache_dir, use_auth_token=use_auth_token
571+
)
572+
573+
if library_name == "timm":
574+
config = PretrainedConfig.from_pretrained(model_id, subfolder, revision)
575+
576+
if config is None:
577+
if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
578+
if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)):
579+
config = AutoConfig.from_pretrained(
580+
os.path.join(model_id, subfolder, CONFIG_NAME), trust_remote_code=trust_remote_code
581+
)
582+
elif CONFIG_NAME in os.listdir(model_id):
583+
config = AutoConfig.from_pretrained(
584+
os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code
585+
)
586+
logger.info(
587+
f"config.json not found in the specified subfolder {subfolder}. Using the top level config.json."
588+
)
589+
else:
590+
raise OSError(f"config.json not found in {model_id} local folder")
591+
else:
592+
config = cls._load_config(
593+
model_id,
594+
revision=revision,
595+
cache_dir=cache_dir,
596+
use_auth_token=use_auth_token,
597+
force_download=force_download,
598+
subfolder=subfolder,
599+
trust_remote_code=trust_remote_code,
600+
)
601+
elif isinstance(config, (str, os.PathLike)):
602+
config = cls._load_config(
603+
config,
604+
revision=revision,
605+
cache_dir=cache_dir,
606+
use_auth_token=use_auth_token,
607+
force_download=force_download,
608+
subfolder=subfolder,
609+
trust_remote_code=trust_remote_code,
610+
)
611+
612+
if export is None:
613+
export = cls._check_export_status(model_id, revision, subfolder)
614+
615+
if not export and trust_remote_code:
616+
logger.warning(
617+
"The argument `trust_remote_code` is to be used along with export=True. It will be ignored."
618+
)
619+
elif export and trust_remote_code is None:
620+
trust_remote_code = False
621+
622+
623+
from_pretrained_method = cls._from_transformers if export else cls._from_pretrained
624+
625+
return from_pretrained_method(
626+
model_id=model_id,
627+
config=config,
628+
revision=revision,
629+
cache_dir=cache_dir,
630+
force_download=force_download,
631+
use_auth_token=use_auth_token,
632+
subfolder=subfolder,
633+
local_files_only=local_files_only,
634+
trust_remote_code=trust_remote_code,
635+
**kwargs,
636+
)
637+
638+
@classmethod
639+
def _check_export_status(cls, model_id: Union[str, Path], revision: Optional[str] = None, subfolder: str = ""):
640+
model_dir = Path(model_id)
641+
if subfolder is not None:
642+
model_dir = model_dir / subfolder
643+
if model_dir.is_dir():
644+
return not (model_dir / OV_XML_FILE_NAME).exists() or not (model_dir / OV_XML_FILE_NAME.replace(".xml", ".bin")).exists()
645+
646+
hf_api = HfApi()
647+
try:
648+
model_info = hf_api.model_info(model_id, revision=revision or "main")
649+
normalized_subfolder = None if subfolder is None else Path(subfolder).as_posix()
650+
model_files = [file.rfilename for file in model_info.siblings if normalized_subfolder is None or file.rfilename.startswith(normalized_subfolder)]
651+
ov_model_path = OV_XML_FILE_NAME if subfolder is None else f"{normalized_subfolder}/{OV_XML_FILE_NAME}"
652+
return not ov_model_path in model_files or not ov_model_path.replace(".xml", ".bin") in model_files
653+
except Exception:
654+
return True

optimum/intel/openvino/modeling_base_seq2seq.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Dict, Optional, Union
2121

2222
import openvino
23-
from huggingface_hub import hf_hub_download
23+
from huggingface_hub import hf_hub_download, HfApi
2424
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
2525
from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
2626
from transformers import GenerationConfig, PretrainedConfig
@@ -362,3 +362,27 @@ def half(self):
362362

363363
def forward(self, *args, **kwargs):
364364
raise NotImplementedError
365+
366+
367+
@classmethod
368+
def _check_export_status(cls, model_id: Union[str, Path], revision: Optional[str] = None, subfolder: str = ""):
369+
model_dir = Path(model_id)
370+
if subfolder is not None:
371+
model_dir = model_dir / subfolder
372+
if model_dir.is_dir():
373+
encoder_exists = (model_dir / OV_ENCODER_NAME).exists() and (model_dir / OV_ENCODER_NAME.replace(".xml", ".bin")).exists()
374+
decoder_exists = (model_dir / OV_DECODER_NAME).exists() and (model_dir / OV_DECODER_NAME.replace(".xml", ".bin")).exists()
375+
return not encoder_exists or not decoder_exists
376+
377+
hf_api = HfApi()
378+
try:
379+
model_info = hf_api.model_info(model_id, revision=revision or "main")
380+
normalized_subfolder = None if subfolder is None else Path(subfolder).as_posix()
381+
model_files = [file.rfilename for file in model_info.siblings if normalized_subfolder is None or file.rfilename.startswith(normalized_subfolder)]
382+
for model_name in [OV_ENCODER_NAME, OV_DECODER_NAME]:
383+
ov_model_path = model_name if subfolder is None else f"{normalized_subfolder}/{model_name}"
384+
if not ov_model_path in model_files or not ov_model_path.replace(".xml", ".bin") in model_files:
385+
return True
386+
return False
387+
except Exception:
388+
return True

0 commit comments

Comments
 (0)