Skip to content

Commit 512d5c6

Browse files
Add argument library_name when parameters standartization (#2179)
* avoid library_name guessing if it is known in parameters standartization * Update optimum/subpackages.py --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
1 parent 856b252 commit 512d5c6

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

optimum/exporters/tasks.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -2067,7 +2067,11 @@ def infer_library_from_model(
20672067
return library_name
20682068

20692069
@classmethod
2070-
def standardize_model_attributes(cls, model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]):
2070+
def standardize_model_attributes(
2071+
cls,
2072+
model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"],
2073+
library_name: Optional[str] = None,
2074+
):
20712075
"""
20722076
Updates the model for export. This function is suitable to make required changes to the models from different
20732077
libraries to follow transformers style.
@@ -2078,7 +2082,8 @@ def standardize_model_attributes(cls, model: Union["PreTrainedModel", "TFPreTrai
20782082
20792083
"""
20802084

2081-
library_name = TasksManager.infer_library_from_model(model)
2085+
if library_name is None:
2086+
library_name = TasksManager.infer_library_from_model(model)
20822087

20832088
if library_name == "diffusers":
20842089
inferred_model_type = None
@@ -2295,7 +2300,7 @@ def get_model_from_task(
22952300
kwargs["from_pt"] = True
22962301
model = model_class.from_pretrained(model_name_or_path, **kwargs)
22972302

2298-
TasksManager.standardize_model_attributes(model)
2303+
TasksManager.standardize_model_attributes(model, library_name=library_name)
22992304

23002305
return model
23012306

optimum/subpackages.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ def load_namespace_modules(namespace: str, module: str):
4646
"""
4747
for dist in importlib_metadata.distributions():
4848
dist_name = dist.metadata["Name"]
49-
if not dist_name.startswith(f"{namespace}-"):
49+
if dist_name is None:
5050
continue
5151
if dist_name == f"{namespace}-benchmark":
5252
continue
53+
if not dist_name.startswith(f"{namespace}-"):
54+
continue
5355
package_import_name = dist_name.replace("-", ".")
5456
module_import_name = f"{package_import_name}.{module}"
5557
if module_import_name in sys.modules:

0 commit comments

Comments
 (0)