Skip to content

Commit 3824300

Browse files
committedMay 24, 2024
debug beam search
1 parent 6289b57 commit 3824300

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed
 

‎optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,22 @@ def forward(
9393
(bs, seqlen, self.num_heads * self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device
9494
)
9595
key = torch.empty(
96-
(bs, prev_seqlen + seqlen, self.num_heads * self.head_dim),
96+
(bs, seqlen, self.num_heads * self.head_dim),
9797
dtype=hidden_states.dtype,
98-
device=hidden_states.device,
98+
device=hidden_states.device
9999
)
100100
value = torch.empty(
101-
(bs, prev_seqlen + seqlen, self.num_heads * self.head_dim),
101+
(bs, seqlen, self.num_heads * self.head_dim),
102102
dtype=hidden_states.dtype,
103-
device=hidden_states.device,
103+
device=hidden_states.device
104104
)
105105
torch.ops.torch_ipex.mm_qkv_out(
106106
hidden_states,
107107
self.qkv_proj_weight,
108108
self.qkv_proj_bias,
109109
query,
110-
key[:, prev_seqlen:, :],
111-
value[:, prev_seqlen:, :],
110+
key,
111+
value,
112112
)
113113
else:
114114
query = torch.empty(
@@ -125,21 +125,17 @@ def forward(
125125
)
126126

127127
query = query.view([bs, seqlen, self.num_heads, self.head_dim])
128-
key = key.view([bs, seqlen + prev_seqlen, self.num_kv_heads, self.head_dim])
128+
key = key.view([bs, seqlen, self.num_kv_heads, self.head_dim])
129129

130-
if hasattr(kwargs, "sin") and hasattr(kwargs, "cos"):
131-
print("cache sin cos")
132-
sin = kwargs["sin"]
133-
cos = kwargs["cos"]
134-
else:
135-
sin, cos = self.ipex_rope.get_sin_cos(seqlen, self.head_dim // 2)
136-
sin = sin.squeeze()[position_ids].unsqueeze(2)
137-
cos = cos.squeeze()[position_ids].unsqueeze(2)
138-
self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key[:, prev_seqlen:, :, :])
139-
value = value.view([bs, seqlen + prev_seqlen, self.num_kv_heads, self.head_dim])
130+
131+
sin = kwargs.pop("sin", None)
132+
cos = kwargs.pop("cos", None)
133+
134+
self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key)
135+
value = value.view([bs, seqlen, self.num_kv_heads, self.head_dim])
140136
if past_key_value is not None:
141-
value[:, :prev_seqlen, :, :] = past_key_value[1].transpose(1, 2)
142-
key[:, :prev_seqlen, :, :] = past_key_value[0].transpose(1, 2)
137+
key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1)
138+
value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1)
143139

144140
query = query.transpose(1, 2)
145141
key = key.transpose(1, 2)

0 commit comments

Comments
 (0)
Please sign in to comment.