Skip to content

Commit e4cc4e6

Browse files
committed
fix jit memory issue
1 parent 9af1b7c commit e4cc4e6

File tree

1 file changed

+3
-14
lines changed

1 file changed

+3
-14
lines changed

optimum/intel/ipex/modeling_base.py

+3 -14
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,8 @@ def ipex_jit_trace(model, task, use_cache):
9090
model.config.return_dict = False
9191

9292
model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
93+
# Disable repack while jit tracing to reduce the memory
94+
ipex._C.disable_jit_linear_repack()
9395
with torch.no_grad():
9496
trace_model = torch.jit.trace(
9597
model,
@@ -171,23 +173,10 @@ def _from_transformers(
171173
model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
172174
traced_model = ipex_jit_trace(model, task, use_cache)
173175

174-
save_dir = TemporaryDirectory()
175-
save_dir_path = Path(save_dir.name)
176-
torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
177176
config.torchscript = True
178177
config.torch_dtype = torch_dtype
179178

180-
return cls._from_pretrained(
181-
model_id=save_dir_path,
182-
config=config,
183-
use_auth_token=use_auth_token,
184-
revision=revision,
185-
force_download=force_download,
186-
cache_dir=cache_dir,
187-
local_files_only=local_files_only,
188-
use_cache=use_cache,
189-
model_dtype=torch_dtype,
190-
)
179+
return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache)
191180

192181
@classmethod
193182
def _from_pretrained(

0 commit comments

Comments
 (0)