Commit 4c1c636

fix llama
1 parent daabe80 commit 4c1c636

File tree

3 files changed: +74 -192 lines changed


optimum/exporters/ipex/llama_functions.py

+26 -54
@@ -6,18 +6,7 @@
 from torch import nn
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+from transformers.models.llama.modeling_llama import repeat_kv


 def llama_layer_norm_forward(self, hidden_states):
@@ -35,51 +24,34 @@ def llama_attn_forward(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
-    concat_qkv = None
-    if hasattr(self, "concat_qkv") and self.concat_qkv is not None:
-        concat_qkv = self.concat_qkv(hidden_states)
-    else:
-        query = self.q_proj(hidden_states)
-        key = self.k_proj(hidden_states)
-        value = self.v_proj(hidden_states)
+
+    query = self.q_proj(hidden_states)
+    key = self.k_proj(hidden_states)
+    value = self.v_proj(hidden_states)

     kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len

-    if concat_qkv is not None and type(concat_qkv) is not tuple:
-        query, key, value = self.ipex_rope(
-            concat_qkv,
-            position_ids,
-            self.num_heads,
-            self.head_dim,
-            self.head_dim // 2,
-            self.head_dim,
-            kv_seq_len,
-            self.concat_qkv._num_concats,
-        )
-    else:
-        if concat_qkv is not None:
-            query, key, value = concat_qkv
-        query = query.view(bsz, q_len, self.num_heads, self.head_dim)
-        key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-        value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-        key = self.ipex_rope(
-            key,
-            position_ids,
-            self.num_key_value_heads,
-            self.head_dim,
-            self.head_dim // 2,
-            self.head_dim,
-            kv_seq_len,
-        )
-        query = self.ipex_rope(
-            query,
-            position_ids,
-            self.num_heads,
-            self.head_dim,
-            self.head_dim // 2,
-            self.head_dim,
-            kv_seq_len,
-        )
+    query = query.view(bsz, q_len, self.num_heads, self.head_dim)
+    key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+    value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+    key = self.ipex_rope(
+        key,
+        position_ids,
+        self.num_key_value_heads,
+        self.head_dim,
+        self.head_dim // 2,
+        self.head_dim,
+        kv_seq_len,
+    )
+    query = self.ipex_rope(
+        query,
+        position_ids,
+        self.num_heads,
+        self.head_dim,
+        self.head_dim // 2,
+        self.head_dim,
+        kv_seq_len,
+    )

     if use_cache:
         (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
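
Note: the local repeat_kv copy is dropped in favor of the upstream transformers implementation. As a quick sanity check, a minimal sketch (illustrative only, not part of the commit) of the equivalence the removed docstring describes, with made-up shapes:

    import torch
    from transformers.models.llama.modeling_llama import repeat_kv

    # (batch, num_key_value_heads, seq_len, head_dim)
    hidden_states = torch.randn(2, 4, 16, 64)
    n_rep = 3  # num_attention_heads // num_key_value_heads

    # repeat_kv is documented as equivalent to repeat_interleave along the head dimension
    expected = torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)
    assert torch.equal(repeat_kv(hidden_states, n_rep), expected)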

optimum/exporters/ipex/model_patcher.py

+4 -6
@@ -1,4 +1,4 @@
-from intel_extension_for_pytorch.llm.modules import ApplyRotaryEmbedding, IndirectKVCache
+from intel_extension_for_pytorch.llm.modules import ApplyRotaryEmbedding, IndirectAccessKVCache
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -50,7 +50,7 @@ def export_llama_model(model):
         model.config.rope_theta,
         model.config.architectures[0],
     )
-    ipex_scale_dot_product = IndirectKVCache(text_max_length=model.config.max_position_embeddings)
+    ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
     patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
     patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

@@ -65,7 +65,5 @@ def export_llama_model(model):

 def export_model(model):
     if isinstance(model, LlamaForCausalLM):
-        export_llama_model(model)
-        return True
-
-    return False
+        model = export_llama_model(model)
+    return model
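
Note: export_model now returns the patched model rather than a boolean; the "is this class supported" check moves to the caller in modeling_base.py below. For context, a minimal sketch of what patch_op is assumed to do here (the repository's own implementation may differ): attach an IPEX module as an attribute on every submodule of the given class, so llama_attn_forward can call self.ipex_rope and self.ipex_scale_dot_product.

    import torch.nn as nn

    def patch_op_sketch(model: nn.Module, module_cls: type, attr_name: str, op) -> None:
        # Walk the module tree and attach `op` to every instance of `module_cls`.
        for module in model.modules():
            if isinstance(module, module_cls):
                setattr(module, attr_name, op)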

optimum/intel/ipex/modeling_base.py

+44 -132
@@ -56,6 +56,16 @@
 logger = logging.getLogger(__name__)


+IPEX_EXPORTED_LIST = ("LlamaForCausalLM", )
+
+
+def is_ipex_exported_model(model_name):
+    for name in IPEX_EXPORTED_LIST:
+        if model_name == name:
+            return True
+    return False
+
+
 def ipex_jit_trace(model):
     sample_inputs = get_dummy_input(model, return_dict=True)
     model.config.return_dict = False
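
The new helper is a plain membership test; an equivalent, slightly more compact form (a sketch, not what the commit adds) would be:

    IPEX_EXPORTED_LIST = ("LlamaForCausalLM",)

    def is_ipex_exported_model(model_name: str) -> bool:
        # True only for architectures that go through the IPEX export path.
        return model_name in IPEX_EXPORTED_LIST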
@@ -138,8 +148,9 @@ def _from_transformers(
         }

         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-        is_ipex_exported = export_model(model)
+        is_ipex_exported = is_ipex_exported_model(model.__class__.__name__)
         if is_ipex_exported:
+            model = export_model(model)
             traced_model = ipex_jit_trace(model)
         else:
             model = patch_decoder_attention_mask(model)
@@ -199,6 +210,8 @@ def _from_pretrained(

         model = torch.jit.load(model_cache_path)
         torch.jit.freeze(model.eval())
+        is_ipex_exported = is_ipex_exported_model(model.original_name)
+        kwargs["is_ipex_exported"] = is_ipex_exported

         return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)

@@ -379,12 +392,15 @@ def __init__(
         except AttributeError:
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)

-        self._reorder_cache = self.model_cls._reorder_cache
+        if self.is_ipex_exported:
+            self._reorder_cache = _ipex_reorder_cache
+        else:
+            self._reorder_cache = self.model_cls._reorder_cache.__get__(self)

         if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}:
             self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama
         else:
-            self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation
+            self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)

         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
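
The .__get__(self) calls matter because a function pulled off a class and assigned to an instance attribute is not bound automatically; going through the descriptor protocol produces a bound method. A small standalone illustration (hypothetical classes, not repository code):

    class Source:
        def prepare(self, x):
            return (self, x)

    class Target:
        pass

    t = Target()
    t.unbound = Source.prepare            # plain function stored on the instance: `self` is not filled in
    t.bound = Source.prepare.__get__(t)   # descriptor protocol: bound method with self=t

    assert t.bound(1) == (t, 1)
    # t.unbound(1) would raise TypeError: 1 is consumed as `self` and `x` is missing.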
@@ -393,37 +409,6 @@ def __init__(
         if warmup:
             self._init_warmup()

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        past_key_values = past_key_values or kwargs.get("past", None)
-
-        if self.use_cache and past_key_values is not None:
-            input_ids = input_ids[:, -1:]
-
-        # `past_key_values` may be in the stardard format (e.g. in contrastive search), converts to bloom's format if needed
-        if past_key_values is not None and self.config.model_type == "bloom":
-            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
-                past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        position_ids = kwargs.get("position_ids", None)
-
-        attention_mask = kwargs.get("attention_mask", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": self.use_cache,
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": None,
-        }
-
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
         nb_pkv = 2
@@ -469,104 +454,6 @@ def _prepare_past_key_values(self, input_ids):

         return past_key_values

-    def _reorder_cache(
-        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called.
-        This is required to match `past_key_values` with the correct beam_idx at every generation step.
-        """
-        if self.config.model_type == "bloom":
-            return self._reorder_cache_bloom(past_key_values, beam_idx)
-
-        if self.is_ipex_exported:
-            if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1:  # discrete kv_cache
-                for layer_past in past_key_values:
-                    layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
-                return past_key_values
-            elif len(past_key_values[0]) == 8:
-                for layer_past in past_key_values:
-                    layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
-                    layer_past[7][layer_past[0].size(-2) - 1] = beam_idx
-                return past_key_values
-            else:
-                return tuple(
-                    tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-                    for layer_past in past_key_values
-                )
-        # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
-
-    # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
-    def _reorder_cache_bloom(
-        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called for bloom architecture.
-        This is required to match `past_key_values` with the correct beam_idx at every generation step.
-        """
-        standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
-
-        # Get a copy of `beam_idx` on all the devices where we need those indices.
-        device_to_beam_idx = {
-            past_state.device: beam_idx.to(past_state.device)
-            for layer_past in past_key_values
-            for past_state in layer_past
-        }
-        reordered_past = tuple(
-            (
-                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
-                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
-            )
-            for layer_past in standardized_past
-        )
-        return self._convert_to_bloom_cache(reordered_past)
-
-    # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache
-    @staticmethod
-    def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
-        """
-        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
-        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
-        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-    # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache
-    def _convert_to_standard_cache(
-        self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...]))
-        """
-        if self.config.model_type != "bloom":
-            return past_key_value
-
-        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        num_heads = batch_size_times_num_heads // batch_size
-        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
-        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -670,3 +557,28 @@ def _prepare_inputs_for_generation_for_llama(
         }
     )
     return model_inputs
+
+
+def _ipex_reorder_cache(
+    past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+) -> Tuple[Tuple[torch.Tensor]]:
+
+    if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1:  # discrete kv_cache
+        for layer_past in past_key_values:
+            layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
+        return past_key_values
+    elif len(past_key_values[0]) == 8:
+        for layer_past in past_key_values:
+            layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
+            layer_past[7][layer_past[0].size(-2) - 1] = beam_idx
+        return past_key_values
+    else:
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past_key_values
+        )
+
+    return tuple(
+        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+        for layer_past in past_key_values
+    )
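
A tiny usage sketch of the new module-level _ipex_reorder_cache on a standard tuple-of-tuples cache (shapes and layer count are made up for illustration); caches with 4 or 8 entries per layer take the in-place indirect-access branches instead:

    import torch

    past_key_values = tuple(
        (torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8))  # (num_beams, heads, seq_len, head_dim)
        for _ in range(2)                                    # two decoder layers
    )
    beam_idx = torch.tensor([2, 0, 1, 3])

    reordered = _ipex_reorder_cache(past_key_values, beam_idx)
    assert torch.equal(reordered[0][0], past_key_values[0][0].index_select(0, beam_idx))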
