diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index a125194250..0a82fe20b5 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -90,6 +90,8 @@ def ipex_jit_trace(model, task, use_cache): model.config.return_dict = False model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True) + # Disable repack while jit tracing to reduce the memory + ipex._C.disable_jit_linear_repack() with torch.no_grad(): trace_model = torch.jit.trace( model, @@ -171,23 +173,10 @@ def _from_transformers( model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) traced_model = ipex_jit_trace(model, task, use_cache) - save_dir = TemporaryDirectory() - save_dir_path = Path(save_dir.name) - torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME) config.torchscript = True config.torch_dtype = torch_dtype - return cls._from_pretrained( - model_id=save_dir_path, - config=config, - use_auth_token=use_auth_token, - revision=revision, - force_download=force_download, - cache_dir=cache_dir, - local_files_only=local_files_only, - use_cache=use_cache, - model_dtype=torch_dtype, - ) + return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False) @classmethod def _from_pretrained(