@@ -96,53 +96,8 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
             )
             return model
         else:
-            # only show logs of error level, since keys_to_ignore_on_load_unexpected is not working without specific model_class
-            transformers.logging.set_verbosity_error()
-            if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path):  # pragma: no cover
-                from transformers.utils import cached_file
-
-                try:
-                    # Load from URL or cache if already cached
-                    resolved_weights_file = cached_file(
-                        model_name_or_path,
-                        filename=WEIGHTS_NAME,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        use_auth_token=use_auth_token,
-                    )
-                except EnvironmentError as err:  # pragma: no cover
-                    logger.error(err)
-                    msg = (
-                        f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n"
-                        f"- '{model_name_or_path}' is a correct model identifier "
-                        f"listed on 'https://huggingface.co/models'\n(make sure "
-                        f"'{model_name_or_path}' is not a path to a local directory with "
-                        f"something else, in that case)\n\n- or '{model_name_or_path}' is "
-                        f"the correct path to a directory containing a file "
-                        f"named one of {WEIGHTS_NAME}\n\n"
-                    )
-                    if revision is not None:
-                        msg += (
-                            f"- or '{revision}' is a valid git identifier "
-                            f"(branch name, a tag name, or a commit id) that "
-                            f"exists for this model name as listed on its model "
-                            f"page on 'https://huggingface.co/models'\n\n"
-                        )
-                    raise EnvironmentError(msg)
-            else:
-                resolved_weights_file = os.path.join(model_name_or_path, WEIGHTS_NAME)
-            state_dict = torch.load(resolved_weights_file, {})
-            model = model_class.from_pretrained(
-                model_name_or_path,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                state_dict=state_dict,
-                **kwargs,
-            )
+            config.torch_dtype = torch.float32
+            model = model_class.from_config(config)
 
         if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path):  # pragma: no cover
             # pylint: disable=E0611
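
For context on the "+" lines above: transformers' from_config() builds a model with freshly initialized parameters and never resolves a weights file, which is why the cached_file/state_dict plumbing could be dropped. A minimal sketch of that pattern follows; the Auto classes and the "facebook/opt-125m" model id are illustrative assumptions, not taken from this commit.

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    # Illustrative model id; any Hugging Face config works here.
    config = AutoConfig.from_pretrained("facebook/opt-125m")
    config.torch_dtype = torch.float32  # record the dtype the model skeleton should use

    # from_config() allocates randomly initialized parameters; no checkpoint is
    # downloaded or loaded, so preprocessed/quantized weights can be restored later.
    model = AutoModelForCausalLM.from_config(config)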