Skip to content

Commit aa8da34

Browse files
committed
change a better way to resolve it
Signed-off-by: xin3he <xin3.he@intel.com>
1 parent 608af90 commit aa8da34

File tree

1 file changed

+2
-47
lines changed

1 file changed

+2
-47
lines changed

neural_compressor/utils/load_huggingface.py

+2-47
Original file line numberDiff line numberDiff line change
@@ -96,53 +96,8 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
9696
)
9797
return model
9898
else:
99-
# only show logs of error level, since keys_to_ignore_on_load_unexpected is not working without specific model_class
100-
transformers.logging.set_verbosity_error()
101-
if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path): # pragma: no cover
102-
from transformers.utils import cached_file
103-
104-
try:
105-
# Load from URL or cache if already cached
106-
resolved_weights_file = cached_file(
107-
model_name_or_path,
108-
filename=WEIGHTS_NAME,
109-
cache_dir=cache_dir,
110-
force_download=force_download,
111-
resume_download=resume_download,
112-
use_auth_token=use_auth_token,
113-
)
114-
except EnvironmentError as err: # pragma: no cover
115-
logger.error(err)
116-
msg = (
117-
f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n"
118-
f"- '{model_name_or_path}' is a correct model identifier "
119-
f"listed on 'https://huggingface.co/models'\n (make sure "
120-
f"'{model_name_or_path}' is not a path to a local directory with "
121-
f"something else, in that case)\n\n- or '{model_name_or_path}' is "
122-
f"the correct path to a directory containing a file "
123-
f"named one of {WEIGHTS_NAME}\n\n"
124-
)
125-
if revision is not None:
126-
msg += (
127-
f"- or '{revision}' is a valid git identifier "
128-
f"(branch name, a tag name, or a commit id) that "
129-
f"exists for this model name as listed on its model "
130-
f"page on 'https://huggingface.co/models'\n\n"
131-
)
132-
raise EnvironmentError(msg)
133-
else:
134-
resolved_weights_file = os.path.join(model_name_or_path, WEIGHTS_NAME)
135-
state_dict = torch.load(resolved_weights_file, {})
136-
model = model_class.from_pretrained(
137-
model_name_or_path,
138-
cache_dir=cache_dir,
139-
force_download=force_download,
140-
resume_download=resume_download,
141-
use_auth_token=use_auth_token,
142-
revision=revision,
143-
state_dict=state_dict,
144-
**kwargs,
145-
)
99+
config.torch_dtype = torch.float32
100+
model = model_class.from_config(config)
146101

147102
if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path): # pragma: no cover
148103
# pylint: disable=E0611

0 commit comments

Comments
 (0)