From e4cc4e61a1a94752365b26526862a52bc0c3c0f3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 22 Mar 2024 09:22:36 -0400 Subject: [PATCH 1/2] fix jit memory issue --- optimum/intel/ipex/modeling_base.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index a125194250..f8aee38ec1 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -90,6 +90,8 @@ def ipex_jit_trace(model, task, use_cache): model.config.return_dict = False model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True) + # Disable repack while jit tracing to reduce the memory + ipex._C.disable_jit_linear_repack() with torch.no_grad(): trace_model = torch.jit.trace( model, @@ -171,23 +173,10 @@ def _from_transformers( model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) traced_model = ipex_jit_trace(model, task, use_cache) - save_dir = TemporaryDirectory() - save_dir_path = Path(save_dir.name) - torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME) config.torchscript = True config.torch_dtype = torch_dtype - return cls._from_pretrained( - model_id=save_dir_path, - config=config, - use_auth_token=use_auth_token, - revision=revision, - force_download=force_download, - cache_dir=cache_dir, - local_files_only=local_files_only, - use_cache=use_cache, - model_dtype=torch_dtype, - ) + return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache) @classmethod def _from_pretrained( From a3fb5b8319f27172b43d9fd9f0d20bebf3403a64 Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Tue, 26 Mar 2024 13:00:26 +0800 Subject: [PATCH 2/2] Update optimum/intel/ipex/modeling_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> --- optimum/intel/ipex/modeling_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index f8aee38ec1..0a82fe20b5 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -176,7 +176,7 @@ def _from_transformers( config.torchscript = True config.torch_dtype = torch_dtype - return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache) + return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False) @classmethod def _from_pretrained(