@@ -346,8 +346,8 @@ def _falcon_model_forward(
 
     # Prepare head mask if needed
     # 1.0 in head_mask indicate we keep the head
-    # attention_probs has shape batch_size x num_heads x N x N
-    # head_mask has shape n_layer x batch x num_heads x N x N
+    # attention_probs has shape batch_size x num_attention_heads x N x N
+    # head_mask has shape n_layer x batch x num_attention_heads x N x N
     head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
    hidden_states = inputs_embeds
 
@@ -707,7 +707,9 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_attention_heads = config.num_attention_heads
+        self.num_groups = self.num_attention_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
         ).repeat_interleave(self.num_groups)
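
For reference, a minimal sketch of the grouped-query head mapping this constructor builds, assuming hypothetical config values (num_attention_heads=8, num_key_value_heads=2): each key/value head index is repeated once per query head in its group.

import torch

num_attention_heads = 8   # hypothetical config value
num_key_value_heads = 2   # hypothetical config value
num_groups = num_attention_heads // num_key_value_heads  # 4 query heads share each kv head

# Same construction as in the diff: kv head index repeated once per query head in its group
kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)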
@@ -894,11 +896,11 @@ def __init__(self, module, device, config) -> None:
     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
             qkv_out = self.concat_qkv(hidden_states)
-            query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
+            query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
             key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
             value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         else:
-            query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
+            query = self.q_proj(hidden_states).view(-1, self.num_attention_heads, self.head_dim)
             key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
             value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
 
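
A small sketch of the fused-projection path above, assuming q_slice and k_slice are flat column offsets into the concatenated QKV output (q_slice = num_attention_heads * head_dim, k_slice = q_slice + num_key_value_heads * head_dim); the sizes below are hypothetical.

import torch

# Hypothetical sizes; q_slice / k_slice are assumed to be flat offsets into the fused QKV output
num_attention_heads, num_key_value_heads, head_dim, tokens = 8, 2, 64, 5
q_slice = num_attention_heads * head_dim                  # 512
k_slice = q_slice + num_key_value_heads * head_dim        # 640
qkv_out = torch.randn(tokens, q_slice + 2 * num_key_value_heads * head_dim)  # [5, 768]

query = qkv_out[:, :q_slice].view(-1, num_attention_heads, head_dim)         # [5, 8, 64]
key = qkv_out[:, q_slice:k_slice].view(-1, num_key_value_heads, head_dim)    # [5, 2, 64]
value = qkv_out[:, k_slice:].view(-1, num_key_value_heads, head_dim)         # [5, 2, 64]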
@@ -916,20 +918,21 @@ def __init__(self, module, device, config):
 
     def qkv_gemm(self, hidden_states):
         qkv_out = self.query_key_value(hidden_states)
         if self.new_decoder_architecture:
-            qkv_out = qkv_out.view(qkv_out.shape[0], -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
+            qkv_out = qkv_out.view(
+                qkv_out.shape[0], -1, self.num_attention_heads // self.num_kv_heads + 2, self.head_dim
+            )
             query = qkv_out[:, :, :-2, :].flatten(1, 2)
             key = qkv_out[:, :, [-2], :].flatten(1, 2)
             value = qkv_out[:, :, [-1], :].flatten(1, 2)
         else:
-            query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
+            query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
             key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
             value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         return query, key, value
 
 
 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
-        self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, device, config)
         _setattr_from_module(self, module)
         if getattr(config, "quantization_config", None) is None:
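
A minimal sketch of the reformatted Falcon reshape in the new_decoder_architecture branch, with hypothetical sizes: Q/K/V are packed per key/value group as [group of queries, one key, one value], and the view/slice/flatten sequence recovers them.

import torch

# Hypothetical sizes for Falcon's new_decoder_architecture layout
num_attention_heads, num_kv_heads, head_dim, tokens = 8, 2, 64, 5
group_size = num_attention_heads // num_kv_heads  # 4 queries per kv group
qkv_out = torch.randn(tokens, num_kv_heads * (group_size + 2) * head_dim)  # [5, 768]

qkv_out = qkv_out.view(qkv_out.shape[0], -1, group_size + 2, head_dim)  # [5, 2, 6, 64]
query = qkv_out[:, :, :-2, :].flatten(1, 2)   # [5, 8, 64]  all query heads
key = qkv_out[:, :, [-2], :].flatten(1, 2)    # [5, 2, 64]  one key head per group
value = qkv_out[:, :, [-1], :].flatten(1, 2)  # [5, 2, 64]  one value head per group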
@@ -952,9 +955,9 @@ def qkv_gemm(self, hidden_states):
             query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
-        query = query.view(-1, self.num_heads, self.head_dim)
-        key = key.view(-1, self.num_heads, self.head_dim)
-        value = value.view(-1, self.num_heads, self.head_dim)
+        query = query.view(-1, self.num_attention_heads, self.head_dim)
+        key = key.view(-1, self.num_attention_heads, self.head_dim)
+        value = value.view(-1, self.num_attention_heads, self.head_dim)
         return query, key, value
 
     def rope(self, query, key, *args, **kwargs):
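
A small sketch of the GPT-2-style fused c_attn split used above, with hypothetical sizes; GPT-2 uses plain multi-head attention, so query, key and value all reshape with num_attention_heads.

import torch

# Hypothetical GPT-2-style fused projection: c_attn outputs [q | k | v] along the last dim
embed_dim, num_attention_heads, tokens = 768, 12, 5
head_dim = embed_dim // num_attention_heads  # 64
split_size = embed_dim
c_attn = torch.nn.Linear(embed_dim, 3 * embed_dim)

hidden_states = torch.randn(tokens, embed_dim)
query, key, value = c_attn(hidden_states).split(split_size, dim=-1)
query = query.view(-1, num_attention_heads, head_dim)  # [5, 12, 64]
key = key.view(-1, num_attention_heads, head_dim)      # [5, 12, 64]
value = value.view(-1, num_attention_heads, head_dim)  # [5, 12, 64]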