# See the License for the specific language governing permissions and
# limitations under the License.

-import copy
import logging
import os
from pathlib import Path
@@ -138,7 +137,6 @@ def _from_pretrained(

        model_save_dir = Path(model_cache_path).parent
        inc_config = None
-        q_config = None
        msg = None
        try:
            inc_config = INCConfig.from_pretrained(model_id)
@@ -153,54 +151,23 @@ def _from_pretrained(
            # load(model_cache_path)
            model = torch.jit.load(model_cache_path)
            model = torch.jit.freeze(model.eval())
-            return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
+            return cls(model, config=config, model_save_dir=model_save_dir, inc_config=inc_config, **kwargs)

        model_class = _get_model_class(config, cls.auto_model_class._model_mapping)
-        keys_to_ignore_on_load_unexpected = copy.deepcopy(
-            getattr(model_class, "_keys_to_ignore_on_load_unexpected", None)
-        )
-        keys_to_ignore_on_load_missing = copy.deepcopy(getattr(model_class, "_keys_to_ignore_on_load_missing", None))
-        # Avoid unnecessary warnings resulting from quantized model initialization
-        quantized_keys_to_ignore_on_load = [
-            r"zero_point",
-            r"scale",
-            r"packed_params",
-            r"constant",
-            r"module",
-            r"best_configure",
-            r"max_val",
-            r"min_val",
-            r"eps",
-            r"fake_quant_enabled",
-            r"observer_enabled",
-        ]
-        if keys_to_ignore_on_load_unexpected is None:
-            model_class._keys_to_ignore_on_load_unexpected = quantized_keys_to_ignore_on_load
-        else:
-            model_class._keys_to_ignore_on_load_unexpected.extend(quantized_keys_to_ignore_on_load)
-        missing_keys_to_ignore_on_load = [r"weight", r"bias"]
-        if keys_to_ignore_on_load_missing is None:
-            model_class._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load
-        else:
-            model_class._keys_to_ignore_on_load_missing.extend(missing_keys_to_ignore_on_load)
+        # Load the state dictionary to check whether the checkpoint contains a quantization config
+        state_dict = torch.load(model_cache_path, map_location="cpu")
+        q_config = state_dict.get("best_configure", None)

-        try:
+        if q_config is None:
            model = model_class.from_pretrained(model_save_dir)
-        except AttributeError:
+        else:
            init_contexts = [no_init_weights(_enable=True)]
            with ContextManagers(init_contexts):
                model = model_class(config)
-
-        model_class._keys_to_ignore_on_load_unexpected = keys_to_ignore_on_load_unexpected
-        model_class._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing
-
-        # Load the state dictionary of the model to verify whether the model is quantized or not
-        state_dict = torch.load(model_cache_path, map_location="cpu")
-        if "best_configure" in state_dict and state_dict["best_configure"] is not None:
-            q_config = state_dict["best_configure"]
        try:
            model = load(model_cache_path, model)
        except Exception as e:
+            # Attach the torch version incompatibility message, if any, to the raised error
            if msg is not None:
                e.args += (msg,)
            raise
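
For reference, a minimal sketch of the quantization-config detection this commit switches to. It assumes a checkpoint saved by Intel Neural Compressor where the quantization config, if present, sits under the `best_configure` key of a plain state dict; the checkpoint filename and the prints are hypothetical placeholders standing in for the two loading branches above.

```python
import torch

# Hypothetical checkpoint path; assumes torch.load returns a plain dict here
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
q_config = state_dict.get("best_configure", None)

if q_config is None:
    # FP32 checkpoint: weights can be loaded through from_pretrained as usual
    print("no quantization config found, loading as a regular model")
else:
    # Quantized checkpoint: instantiate without weight initialization, then
    # let neural_compressor's load() restore the quantized modules
    print("quantization config found, restoring the quantized model")
```

Inspecting the checkpoint itself replaces the previous approach of suppressing missing/unexpected-key warnings through the `_keys_to_ignore_on_load_*` class attributes, which mutated the model class globally and had to be restored afterwards.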