Commit f52d7c8

Fix loading of INC quantized model (#452)
1 parent b14059f commit f52d7c8

File tree

1 file changed (+7 −40 lines)

optimum/intel/neural_compressor/modeling_base.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import logging
 import os
 from pathlib import Path
@@ -138,7 +137,6 @@ def _from_pretrained(
 
         model_save_dir = Path(model_cache_path).parent
         inc_config = None
-        q_config = None
         msg = None
         try:
             inc_config = INCConfig.from_pretrained(model_id)
@@ -153,54 +151,23 @@ def _from_pretrained(
             # load(model_cache_path)
             model = torch.jit.load(model_cache_path)
             model = torch.jit.freeze(model.eval())
-            return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
+            return cls(model, config=config, model_save_dir=model_save_dir, inc_config=inc_config, **kwargs)
 
         model_class = _get_model_class(config, cls.auto_model_class._model_mapping)
-        keys_to_ignore_on_load_unexpected = copy.deepcopy(
-            getattr(model_class, "_keys_to_ignore_on_load_unexpected", None)
-        )
-        keys_to_ignore_on_load_missing = copy.deepcopy(getattr(model_class, "_keys_to_ignore_on_load_missing", None))
-        # Avoid unnecessary warnings resulting from quantized model initialization
-        quantized_keys_to_ignore_on_load = [
-            r"zero_point",
-            r"scale",
-            r"packed_params",
-            r"constant",
-            r"module",
-            r"best_configure",
-            r"max_val",
-            r"min_val",
-            r"eps",
-            r"fake_quant_enabled",
-            r"observer_enabled",
-        ]
-        if keys_to_ignore_on_load_unexpected is None:
-            model_class._keys_to_ignore_on_load_unexpected = quantized_keys_to_ignore_on_load
-        else:
-            model_class._keys_to_ignore_on_load_unexpected.extend(quantized_keys_to_ignore_on_load)
-        missing_keys_to_ignore_on_load = [r"weight", r"bias"]
-        if keys_to_ignore_on_load_missing is None:
-            model_class._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load
-        else:
-            model_class._keys_to_ignore_on_load_missing.extend(missing_keys_to_ignore_on_load)
+        # Load the state dictionary of the model to verify whether the model to get the quantization config
+        state_dict = torch.load(model_cache_path, map_location="cpu")
+        q_config = state_dict.get("best_configure", None)
 
-        try:
+        if q_config is None:
             model = model_class.from_pretrained(model_save_dir)
-        except AttributeError:
+        else:
             init_contexts = [no_init_weights(_enable=True)]
             with ContextManagers(init_contexts):
                 model = model_class(config)
-
-        model_class._keys_to_ignore_on_load_unexpected = keys_to_ignore_on_load_unexpected
-        model_class._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing
-
-        # Load the state dictionary of the model to verify whether the model is quantized or not
-        state_dict = torch.load(model_cache_path, map_location="cpu")
-        if "best_configure" in state_dict and state_dict["best_configure"] is not None:
-            q_config = state_dict["best_configure"]
             try:
                 model = load(model_cache_path, model)
             except Exception as e:
+                # For incompatible torch version check
                 if msg is not None:
                     e.args += (msg,)
                 raise
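With this change, loading first inspects the saved state dict: its "best_configure" entry (if any) becomes the quantization config, a plain float model is loaded via from_pretrained when the entry is absent, and otherwise the quantized model is rebuilt with load(model_cache_path, model). A minimal usage sketch, assuming a hypothetical local directory ./quantized_model produced by an INC quantization run and the INCModelForSequenceClassification wrapper exported by optimum.intel.neural_compressor:

# Minimal sketch (hypothetical path): load an Intel Neural Compressor (INC)
# quantized checkpoint through the optimum-intel wrapper, which goes through
# the _from_pretrained path patched above.
from optimum.intel.neural_compressor import INCModelForSequenceClassification

# "./quantized_model" stands in for a directory holding config.json and the
# saved state dict; a "best_configure" entry in that state dict triggers
# reconstruction of the quantized model, otherwise a float model is loaded.
model = INCModelForSequenceClassification.from_pretrained("./quantized_model")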
