@@ -93,22 +93,22 @@ def forward(
93
93
(bs , seqlen , self .num_heads * self .head_dim ), dtype = hidden_states .dtype , device = hidden_states .device
94
94
)
95
95
key = torch .empty (
96
- (bs , prev_seqlen + seqlen , self .num_heads * self .head_dim ),
96
+ (bs , seqlen , self .num_heads * self .head_dim ),
97
97
dtype = hidden_states .dtype ,
98
- device = hidden_states .device ,
98
+ device = hidden_states .device
99
99
)
100
100
value = torch .empty (
101
- (bs , prev_seqlen + seqlen , self .num_heads * self .head_dim ),
101
+ (bs , seqlen , self .num_heads * self .head_dim ),
102
102
dtype = hidden_states .dtype ,
103
- device = hidden_states .device ,
103
+ device = hidden_states .device
104
104
)
105
105
torch .ops .torch_ipex .mm_qkv_out (
106
106
hidden_states ,
107
107
self .qkv_proj_weight ,
108
108
self .qkv_proj_bias ,
109
109
query ,
110
- key [:, prev_seqlen :, :] ,
111
- value [:, prev_seqlen :, :] ,
110
+ key ,
111
+ value ,
112
112
)
113
113
else :
114
114
query = torch .empty (
@@ -125,21 +125,17 @@ def forward(
125
125
)
126
126
127
127
query = query .view ([bs , seqlen , self .num_heads , self .head_dim ])
128
- key = key .view ([bs , seqlen + prev_seqlen , self .num_kv_heads , self .head_dim ])
128
+ key = key .view ([bs , seqlen , self .num_kv_heads , self .head_dim ])
129
129
130
- if hasattr (kwargs , "sin" ) and hasattr (kwargs , "cos" ):
131
- print ("cache sin cos" )
132
- sin = kwargs ["sin" ]
133
- cos = kwargs ["cos" ]
134
- else :
135
- sin , cos = self .ipex_rope .get_sin_cos (seqlen , self .head_dim // 2 )
136
- sin = sin .squeeze ()[position_ids ].unsqueeze (2 )
137
- cos = cos .squeeze ()[position_ids ].unsqueeze (2 )
138
- self .ipex_rope .apply_embedding (query , sin , cos , self .head_dim // 2 , key [:, prev_seqlen :, :, :])
139
- value = value .view ([bs , seqlen + prev_seqlen , self .num_kv_heads , self .head_dim ])
130
+
131
+ sin = kwargs .pop ("sin" , None )
132
+ cos = kwargs .pop ("cos" , None )
133
+
134
+ self .ipex_rope .apply_embedding (query , sin , cos , self .head_dim // 2 , key )
135
+ value = value .view ([bs , seqlen , self .num_kv_heads , self .head_dim ])
140
136
if past_key_value is not None :
141
- value [:, : prev_seqlen , :, :] = past_key_value [1 ].transpose (1 , 2 )
142
- key [:, : prev_seqlen , :, :] = past_key_value [0 ].transpose (1 , 2 )
137
+ key = torch . cat ([ past_key_value [0 ].transpose (1 , 2 ), key ], dim = 1 )
138
+ value = torch . cat ([ past_key_value [1 ].transpose (1 , 2 ), value ], dim = 1 )
143
139
144
140
query = query .transpose (1 , 2 )
145
141
key = key .transpose (1 , 2 )
0 commit comments