 from transformers.models.auto.auto_factory import _get_model_class
 from transformers.utils.generic import ContextManagers
 
+from optimum.intel.generation import BaseModelForCausalLM
+
 from ...modeling_base import OptimizedModel
 from ..utils.import_utils import _torch_version, is_torch_version
 from .configuration import INCConfig
@@ -83,11 +85,6 @@ def __init__(
             "cuda:0" if torch.cuda.is_available() else "cpu"
         )
 
-        if getattr(self.config, "backend", None) == "ipex":
-            raise NotImplementedError(
-                "`INCModel` does not supported the loading of model resulting from IPEX, please use `IPEXModel` to load your model instead instead"
-            )
-
         # Registers the INCModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
         # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
         AutoConfig.register(self.base_model_prefix, AutoConfig)
@@ -143,11 +140,19 @@ def _from_pretrained(
                 f"Please check if torch quantization the model was obtained with is compatible with {_torch_version}."
             )
 
+        if getattr(config, "backend", None) == "ipex" or getattr(config, "torchscript", False):
+            logger.warning(
+                f"Using `{cls.__name__}` to load a TorchScript model will be deprecated in v1.15.0, to load your model please use `{cls.__name__.replace('INC', 'IPEX')}` instead."
+            )
+            model = torch.jit.load(model_cache_path)
+            model = torch.jit.freeze(model.eval())
+            return cls(model, config=config, model_save_dir=model_save_dir, inc_config=inc_config, **kwargs)
+
         model_class = _get_model_class(config, cls.auto_model_class._model_mapping)
         # Load the state dictionary of the model to verify whether the model to get the quantization config
         state_dict = torch.load(model_cache_path, map_location="cpu")
-        q_config = state_dict.get("best_configure", None)
 
+        q_config = state_dict.get("best_configure", None)
         if q_config is None:
             model = model_class.from_pretrained(model_save_dir)
         else:
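
Aside: the new branch above short-circuits `_from_pretrained` for TorchScript checkpoints instead of raising as `__init__` used to. A minimal standalone sketch of the same load path, assuming a local TorchScript archive (the file name here is illustrative, not the library's constant):

```python
import torch

# Hypothetical path to a TorchScript export; in the real code this is
# `model_cache_path` resolved from the hub or local cache.
model_cache_path = "quantized_model/pytorch_model.bin"

model = torch.jit.load(model_cache_path)  # deserialize the TorchScript module
model = torch.jit.freeze(model.eval())    # fold weights/attributes into the graph for inference
```

`torch.jit.freeze` requires the module to be in eval mode, which is why `.eval()` is called inline before freezing.
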
@@ -169,10 +174,13 @@ def _from_pretrained(
     def _save_pretrained(self, save_directory: Union[str, Path]):
         output_path = os.path.join(save_directory, WEIGHTS_NAME)
 
-        state_dict = self.model.state_dict()
-        if self._q_config:
-            state_dict["best_configure"] = self._q_config
-        torch.save(state_dict, output_path)
+        if isinstance(self.model, torch.nn.Module):
+            state_dict = self.model.state_dict()
+            if self._q_config:
+                state_dict["best_configure"] = self._q_config
+            torch.save(state_dict, output_path)
+        else:
+            torch.jit.save(self.model, output_path)
 
         if self.inc_config:
             self.inc_config.save_pretrained(save_directory)
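
Aside: the saving logic now mirrors the two load paths. For eager `torch.nn.Module`s the quantization config rides inside the checkpoint itself under the `"best_configure"` key, which is exactly the key `_from_pretrained` reads back. A minimal round-trip sketch (the module and config values are illustrative only):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
q_config = {"approach": "post_training_static_quant"}  # hypothetical quantization config

# Save: embed the config alongside the weights.
state_dict = model.state_dict()
state_dict["best_configure"] = q_config
torch.save(state_dict, "pytorch_model.bin")

# Load: the config comes back out of the same checkpoint.
reloaded = torch.load("pytorch_model.bin", map_location="cpu")
assert reloaded.get("best_configure") == q_config
```

TorchScript modules are serialized as self-contained archives via `torch.jit.save`, hence the separate branch.
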
@@ -244,6 +252,29 @@ class INCModelForXLNetLM(INCModel):
     export_feature = "fill-mask"
 
 
-class INCModelForCausalLM(INCModel):
+class INCModelForCausalLM(INCModel, BaseModelForCausalLM):
     auto_model_class = AutoModelForCausalLM
     export_feature = "text-generation"
+    forward = BaseModelForCausalLM.forward
+    generate = BaseModelForCausalLM.generate
+    can_generate = BaseModelForCausalLM.can_generate
+
+    def __init__(
+        self,
+        model,
+        config: PretrainedConfig = None,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        q_config: Dict = None,
+        inc_config: Dict = None,
+        use_cache: bool = True,
+        **kwargs,
+    ):
+        super(INCModelForCausalLM, self).__init__(
+            model=model,
+            config=config,
+            model_save_dir=model_save_dir,
+            q_config=q_config,
+            inc_config=inc_config,
+            use_cache=use_cache,
+            **kwargs,
+        )
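
Aside: with `BaseModelForCausalLM` mixed in, the class gains `forward`, `generate`, and `can_generate`, so it can drive text generation directly. A hedged usage sketch; the checkpoint name is a placeholder, not a verified model:

```python
from transformers import AutoTokenizer
from optimum.intel import INCModelForCausalLM

model_id = "username/quantized-gpt2"  # hypothetical INC-quantized checkpoint
model = INCModelForCausalLM.from_pretrained(model_id, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The quantized model says:", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Since `INCModel` precedes `BaseModelForCausalLM` in the MRO, the explicit `forward = BaseModelForCausalLM.forward` (and friends) pins which implementation wins rather than relying on inheritance order.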