
Commit 38ed051

committed Feb 28, 2024
fix comments
1 parent b04b435 commit 38ed051


4 files changed: 65 additions, 76 deletions

 

optimum/exporters/ipex/__init__.py

1 deletion

@@ -1 +0,0 @@
-from .model_patcher import export_model

optimum/exporters/ipex/llama_functions.py

14 additions, 40 deletions

@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from typing import List, Optional, Tuple, Union

@@ -96,46 +110,6 @@ def llama_attn_forward(
     return attn_output, attn_weights, past_key_value


-def prepare_inputs_for_generation(
-    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-):
-    if past_key_values is not None:
-        past_length = past_key_values[0][0].shape[2]
-
-        # Some generation methods already pass only the last input ID
-        if input_ids.shape[1] > past_length:
-            remove_prefix_length = past_length
-        else:
-            # Default to old behavior: keep only final ID
-            remove_prefix_length = input_ids.shape[1] - 1
-
-        input_ids = input_ids[:, remove_prefix_length:]
-
-    position_ids = kwargs.get("position_ids", None)
-    if attention_mask is not None and position_ids is None:
-        # create position_ids on the fly for batch generation
-        position_ids = attention_mask.long().cumsum(-1) - 1
-        position_ids.masked_fill_(attention_mask == 0, 1)
-        if past_key_values:
-            position_ids = position_ids[:, -input_ids.shape[1] :]
-
-    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-    if inputs_embeds is not None and past_key_values is None:
-        model_inputs = {"inputs_embeds": inputs_embeds}
-    else:
-        model_inputs = {"input_ids": input_ids}
-
-    model_inputs.update(
-        {
-            "position_ids": position_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "attention_mask": attention_mask,
-        }
-    )
-    return model_inputs
-
-
 def llama_model_forward(
     self,
     input_ids: torch.LongTensor = None,

optimum/exporters/ipex/model_patcher.py

22 additions, 6 deletions

@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from intel_extension_for_pytorch.llm.modules import ApplyRotaryEmbedding, IndirectAccessKVCache
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
@@ -12,10 +26,13 @@
     llama_attn_forward,
     llama_layer_norm_forward,
     llama_model_forward,
-    prepare_inputs_for_generation,
 )


+IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
+IPEX_EXPORTED_TASK = ("text-generation",)
+
+
 def convert_func(m, func_name, new_function):
     bound_method = new_function.__get__(m, m.__class__)
     setattr(m, func_name, bound_method)
@@ -43,7 +60,7 @@ def patch_op(m, target_m, new_op_name, new_op):
         patch_op(sub_m, target_m, new_op_name, new_op)


-def export_llama_model(model):
+def _patch_llama_model(model):
     ipex_rope = ApplyRotaryEmbedding(
         model.config.max_position_embeddings,
         model.config.hidden_size // model.config.num_attention_heads,
@@ -54,7 +71,6 @@ def export_llama_model(model):
     patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
     patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

-    convert_func(model, "prepare_inputs_for_generation", prepare_inputs_for_generation)
     convert_functions(model, LlamaModel, "forward", llama_model_forward)
     convert_functions(model, LlamaAttention, "forward", llama_attn_forward)
     convert_functions(model, LlamaRMSNorm, "forward", llama_layer_norm_forward)
@@ -63,7 +79,7 @@ def export_llama_model(model):
     return model


-def export_model(model):
+def _patch_model(model):
     if isinstance(model, LlamaForCausalLM):
-        model = export_llama_model(model)
-    return model
+        model = _patch_llama_model(model)
+    return model
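Note: the patching helpers in this file re-bind a plain function as a method on an existing instance via Python's descriptor protocol. A minimal standalone sketch of that pattern, with made-up names purely for illustration (not part of this commit):

class Greeter:
    def hello(self):
        return "original"


def patched_hello(self):
    return "patched"


g = Greeter()
# Mirror convert_func: bind the replacement to the instance, then attach it,
# so g.hello() dispatches to the new function without touching the class.
bound_method = patched_hello.__get__(g, g.__class__)
setattr(g, "hello", bound_method)
assert g.hello() == "patched"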

optimum/intel/ipex/modeling_base.py

29 additions, 29 deletions

@@ -24,6 +24,7 @@
 from huggingface_hub import hf_hub_download
 from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp
 from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
+from packaging import version
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -47,7 +48,7 @@
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager

-from ...exporters.ipex import export_model
+from ...exporters.ipex.model_patcher import IPEX_EXPORTED_ARCH, IPEX_EXPORTED_TASK, _patch_model
 from ..generation.modeling import jit_trace, prepare_jit_inputs
 from ..utils.import_utils import is_torch_version, is_transformers_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
@@ -56,17 +57,28 @@
 logger = logging.getLogger(__name__)


-IPEX_EXPORTED_LIST = ("LlamaForCausalLM", )
+IPEX_SUPPORT_MODEL_TYPES = "llama"


-def is_ipex_exported_model(model_name):
-    for name in IPEX_EXPORTED_LIST:
-        if model_name == name:
-            return True
-    return False
+def is_model_support_ipex_export(model, task):
+    if isinstance(model, torch.jit.ScriptModule):
+        is_ipex_exported = model.original_name in IPEX_EXPORTED_ARCH
+    else:
+        is_ipex_exported = model.config.model_type in IPEX_SUPPORT_MODEL_TYPES and task in IPEX_EXPORTED_TASK
+
+    return is_ipex_exported
+
+
+def ipex_jit_trace(model, task, use_cache):
+    if version.parse(ipex.__version__) <= version.parse("2.3.0") or not is_model_support_ipex_export(model, task):
+        model = patch_decoder_attention_mask(model)
+        model = ipex.optimize(model, dtype=model.dtype, level="O1", auto_kernel_selection=True)
+        return jit_trace(model, task, use_cache)

+    if is_torch_version("<", "2.1.0"):
+        raise ImportError("`torch>=2.1.0` is needed to trace your model")

-def ipex_jit_trace(model):
+    model = _patch_model(model)
     sample_inputs = get_dummy_input(model, return_dict=True)
     model.config.return_dict = False
     _enable_tpp()
@@ -104,7 +116,7 @@ def __init__(
         self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model.to(self._device)
         self.model_save_dir = model_save_dir
-        self.is_ipex_exported = kwargs.get("is_ipex_exported", None)
+        self.is_ipex_exported = is_model_support_ipex_export(model, self.export_feature)

         self.input_names = {
             inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
@@ -148,14 +160,7 @@ def _from_transformers(
         }

         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-        is_ipex_exported = is_ipex_exported_model(model.__class__.__name__)
-        if is_ipex_exported:
-            model = export_model(model)
-            traced_model = ipex_jit_trace(model)
-        else:
-            model = patch_decoder_attention_mask(model)
-            model = ipex.optimize(model, dtype=torch_dtype, level="O1", auto_kernel_selection=True)
-            traced_model = jit_trace(model, task, use_cache)
+        traced_model = ipex_jit_trace(model, task, use_cache)

         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
@@ -173,7 +178,6 @@ def _from_transformers(
             local_files_only=local_files_only,
             use_cache=use_cache,
             model_dtype=torch_dtype,
-            is_ipex_exported=is_ipex_exported,
         )

     @classmethod
@@ -210,8 +214,6 @@ def _from_pretrained(

         model = torch.jit.load(model_cache_path)
         torch.jit.freeze(model.eval())
-        is_ipex_exported = is_ipex_exported_model(model.original_name)
-        kwargs["is_ipex_exported"] = is_ipex_exported

         return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)

@@ -372,7 +374,6 @@ def __init__(
         model_type = config.model_type.replace("_", "-")
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
         self.use_cache = "past_key_values" in self.input_names
-        self.is_ipex_exported = kwargs.get("is_ipex_exported", None)

         if use_cache ^ self.use_cache:
             raise ValueError(
@@ -422,7 +423,11 @@ def _prepare_past_key_values(self, input_ids):
         num_attention_heads = self.normalized_config.num_attention_heads

         if self.is_ipex_exported:
-            beam_idx_tmp = torch.zeros((2048, input_ids.shape[0]), dtype=torch.long).contiguous()
+            # Indirect access kv cache has a different data layout compared with most transformers model,
+            # see https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/llm.html#indirect-access-kv-cache
+            beam_idx_tmp = torch.zeros(
+                (self.config.max_position_embeddings, input_ids.shape[0]), dtype=torch.long
+            ).contiguous()
             past_key_values = tuple(
                 [
                     (
@@ -562,8 +567,8 @@ def _prepare_inputs_for_generation_for_llama(
 def _ipex_reorder_cache(
     past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
 ) -> Tuple[Tuple[torch.Tensor]]:
-
-    if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1:  # discrete kv_cache
+    # Ipex patched model uses indirect access kv cache which has a different shape with other transformers models
+    if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1:
         for layer_past in past_key_values:
             layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
         return past_key_values
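Aside: to make the branch condition above concrete, here is a hedged sketch of one per-layer entry in the indirect-access KV cache layout, built only from what this diff asserts (a 4-element tuple whose first element has a trailing dimension of 1 and whose fourth element is the beam-index buffer allocated as beam_idx_tmp above); the key/value placeholder shapes are illustrative assumptions, not the exact buffers IPEX allocates:

import torch

max_positions, batch_size = 2048, 1  # illustrative sizes

beam_idx_tmp = torch.zeros((max_positions, batch_size), dtype=torch.long).contiguous()
layer_past = (
    torch.zeros(1, 0, 0, 1, dtype=torch.long),  # assumed placeholder; its trailing dim of 1 is what the check relies on
    torch.zeros(1, 1, 1, 1),                     # assumed key-cache placeholder
    torch.zeros(1, 1, 1, 1),                     # assumed value-cache placeholder
    beam_idx_tmp,                                # beam-index buffer that _ipex_reorder_cache updates in place
)

# The condition _ipex_reorder_cache uses to detect the indirect-access layout:
assert len(layer_past) == 4 and layer_past[0].shape[-1] == 1

The final hunk below then drops an unreachable duplicate of the fallback return.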
@@ -577,8 +582,3 @@ def _ipex_reorder_cache(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past_key_values
         )
-
-    return tuple(
-        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-        for layer_past in past_key_values
-    )
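End to end, this commit makes model patching happen inside ipex_jit_trace when a supported llama text-generation model is exported, instead of requiring an explicit export_model call. A hedged usage sketch (the checkpoint name and the exact from_pretrained arguments are assumptions and may differ across optimum-intel versions):

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint, assumed accessible
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True traces the model with TorchScript; for a llama text-generation
# checkpoint the IPEX-specific patching path added in this commit is taken automatically.
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("The weather today is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))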
