Commit 3b627f4

Enable loading of torchscript model with INC and add warning (#540)

1 parent 1b5c3cb · commit 3b627f4
2 files changed (+63, -11)

optimum/intel/neural_compressor/modeling_base.py (+42, -11)
@@ -40,6 +40,8 @@
 from transformers.models.auto.auto_factory import _get_model_class
 from transformers.utils.generic import ContextManagers

+from optimum.intel.generation import BaseModelForCausalLM
+
 from ...modeling_base import OptimizedModel
 from ..utils.import_utils import _torch_version, is_torch_version
 from .configuration import INCConfig
@@ -83,11 +85,6 @@ def __init__(
             "cuda:0" if torch.cuda.is_available() else "cpu"
         )

-        if getattr(self.config, "backend", None) == "ipex":
-            raise NotImplementedError(
-                "`INCModel` does not supported the loading of model resulting from IPEX, please use `IPEXModel` to load your model instead instead"
-            )
-
         # Registers the INCModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
         # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
         AutoConfig.register(self.base_model_prefix, AutoConfig)
@@ -143,11 +140,19 @@ def _from_pretrained(
                 f"Please check if torch quantization the model was obtained with is compatible with {_torch_version}."
             )

+        if getattr(config, "backend", None) == "ipex" or getattr(config, "torchscript", False):
+            logger.warning(
+                f"Using `{cls.__name__}` to load a TorchScript model will be deprecated in v1.15.0, to load your model please use `{cls.__name__.replace('INC', 'IPEX')}` instead."
+            )
+            model = torch.jit.load(model_cache_path)
+            model = torch.jit.freeze(model.eval())
+            return cls(model, config=config, model_save_dir=model_save_dir, inc_config=inc_config, **kwargs)
+
         model_class = _get_model_class(config, cls.auto_model_class._model_mapping)
         # Load the state dictionary of the model to verify whether the model to get the quantization config
         state_dict = torch.load(model_cache_path, map_location="cpu")
-        q_config = state_dict.get("best_configure", None)

+        q_config = state_dict.get("best_configure", None)
         if q_config is None:
             model = model_class.from_pretrained(model_save_dir)
         else:
@@ -169,10 +174,13 @@ def _from_pretrained(
     def _save_pretrained(self, save_directory: Union[str, Path]):
         output_path = os.path.join(save_directory, WEIGHTS_NAME)

-        state_dict = self.model.state_dict()
-        if self._q_config:
-            state_dict["best_configure"] = self._q_config
-        torch.save(state_dict, output_path)
+        if isinstance(self.model, torch.nn.Module):
+            state_dict = self.model.state_dict()
+            if self._q_config:
+                state_dict["best_configure"] = self._q_config
+            torch.save(state_dict, output_path)
+        else:
+            torch.jit.save(self.model, output_path)

         if self.inc_config:
             self.inc_config.save_pretrained(save_directory)
@@ -244,6 +252,29 @@ class INCModelForXLNetLM(INCModel):
     export_feature = "fill-mask"


-class INCModelForCausalLM(INCModel):
+class INCModelForCausalLM(INCModel, BaseModelForCausalLM):
     auto_model_class = AutoModelForCausalLM
     export_feature = "text-generation"
+    forward = BaseModelForCausalLM.forward
+    generate = BaseModelForCausalLM.generate
+    can_generate = BaseModelForCausalLM.can_generate
+
+    def __init__(
+        self,
+        model,
+        config: PretrainedConfig = None,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        q_config: Dict = None,
+        inc_config: Dict = None,
+        use_cache: bool = True,
+        **kwargs,
+    ):
+        super(INCModelForCausalLM, self).__init__(
+            model=model,
+            config=config,
+            model_save_dir=model_save_dir,
+            q_config=q_config,
+            inc_config=inc_config,
+            use_cache=use_cache,
+            **kwargs,
+        )
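
With this change, a checkpoint whose config sets torchscript=True (or backend="ipex") is loaded through torch.jit.load and torch.jit.freeze instead of the usual state_dict path, and a deprecation warning points users to the IPEXModel counterpart. A minimal sketch of the resulting user-facing flow, using the tiny TorchScript test model exercised in the test suite below (the subfolder layout is taken from that test; the import path is standard optimum-intel usage, the generation parameters are illustrative):

    from transformers import AutoTokenizer

    from optimum.intel import INCModelForCausalLM

    # The config of this repo sets `torchscript=True`, so loading now goes
    # through the new `torch.jit.load` + `torch.jit.freeze` branch and logs:
    #   "Using `INCModelForCausalLM` to load a TorchScript model will be
    #    deprecated in v1.15.0, to load your model please use
    #    `IPEXModelForCausalLM` instead."
    model = INCModelForCausalLM.from_pretrained(
        "echarlaix/tiny-random-gpt2-torchscript", use_cache=True, subfolder="model_with_pkv"
    )

    tokenizer = AutoTokenizer.from_pretrained("echarlaix/tiny-random-gpt2-torchscript")
    tokens = tokenizer("This is a sample input", return_tensors="pt")

    # `generate` is inherited from `BaseModelForCausalLM` (see the class diff above).
    outputs = model.generate(**tokens, max_length=32, num_beams=1)
    print(tokenizer.decode(outputs[0]))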

tests/neural_compressor/test_modeling.py (+21)
@@ -122,3 +122,24 @@ def test_pipeline(self, model_id, task):
         inputs *= 2

         pipe(*inputs)
+
+    def test_compare_with_and_without_past_key_values(self):
+        model_id = "echarlaix/tiny-random-gpt2-torchscript"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer("This is a sample input", return_tensors="pt")
+
+        model_with_pkv = INCModelForCausalLM.from_pretrained(model_id, use_cache=True, subfolder="model_with_pkv")
+
+        outputs_with_pkv = model_with_pkv.generate(
+            **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+        )
+        model_without_pkv = INCModelForCausalLM.from_pretrained(
+            model_id, use_cache=False, subfolder="model_without_pkv"
+        )
+
+        outputs_without_pkv = model_without_pkv.generate(
+            **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+        )
+        self.assertEqual(outputs_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(torch.equal(outputs_with_pkv, outputs_without_pkv))
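
For context, `_save_pretrained` now mirrors the loading branch: an eager torch.nn.Module is still serialized via its state_dict (with the "best_configure" entry when a quantization config is present), while a jitted model is written with torch.jit.save. A hedged sketch of how a TorchScript checkpoint of this shape can be produced in the first place (the tiny model id and tracing inputs are illustrative assumptions, not part of this commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # `torchscript=True` makes the model return tuples rather than dicts,
    # which makes it traceable; it is also the config flag the new loading
    # branch checks for.
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-gpt2", torchscript=True
    )
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model.eval()

    dummy = tokenizer("This is a sample input", return_tensors="pt")
    traced = torch.jit.trace(model, [dummy["input_ids"]])

    # WEIGHTS_NAME ("pytorch_model.bin") is the file `_from_pretrained`
    # hands to `torch.jit.load`.
    torch.jit.save(traced, "pytorch_model.bin")
    model.config.save_pretrained(".")  # the saved config carries `torchscript=True`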
