@@ -90,6 +90,8 @@ def ipex_jit_trace(model, task, use_cache):
     model.config.return_dict = False
 
     model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
+    # Disable repack while jit tracing to reduce memory usage
+    ipex._C.disable_jit_linear_repack()
     with torch.no_grad():
         trace_model = torch.jit.trace(
             model,
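For context, a minimal, self-contained sketch of the tracing flow this hunk modifies. The toy module, dtype, and input shape are illustrative assumptions rather than the library's actual model handling; `ipex._C.disable_jit_linear_repack()` is the private hook added above, so its availability depends on the installed IPEX version.

import torch
import intel_extension_for_pytorch as ipex

# Toy stand-in for the transformers model passed to ipex_jit_trace (illustration only).
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.fc(x)

model = ipex.optimize(TinyModel().eval(), dtype=torch.float32, inplace=True)

# Disable linear weight repacking during tracing to keep peak memory down,
# mirroring the change above (private IPEX hook; version-dependent).
ipex._C.disable_jit_linear_repack()

with torch.no_grad():
    example_input = torch.randn(2, 16)
    trace_model = torch.jit.trace(model, example_input)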
@@ -171,23 +173,10 @@ def _from_transformers(
         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
         traced_model = ipex_jit_trace(model, task, use_cache)
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-        torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
         config.torchscript = True
         config.torch_dtype = torch_dtype
 
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            local_files_only=local_files_only,
-            use_cache=use_cache,
-            model_dtype=torch_dtype,
-        )
+        return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache)
 
     @classmethod
     def _from_pretrained(
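A usage sketch of the path that now returns the wrapped traced model directly instead of round-tripping through a temporary directory. The checkpoint name is a placeholder, and the `IPEXModelForCausalLM` / `export=True` entry point is assumed from optimum-intel's public API rather than shown in this diff.

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # placeholder checkpoint, assumption for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True routes through _from_transformers above: the model is jit-traced
# and handed straight to the class constructor instead of being saved and reloaded.
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("Hello, world", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))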