@@ -346,8 +346,8 @@ def _falcon_model_forward(
346
346
347
347
# Prepare head mask if needed
348
348
# 1.0 in head_mask indicate we keep the head
349
- # attention_probs has shape batch_size x num_heads x N x N
350
- # head_mask has shape n_layer x batch x num_heads x N x N
349
+ # attention_probs has shape batch_size x num_attention_heads x N x N
350
+ # head_mask has shape n_layer x batch x num_attention_heads x N x N
351
351
head_mask = self .get_head_mask (head_mask , self .config .num_hidden_layers )
352
352
hidden_states = inputs_embeds
353
353
@@ -707,7 +707,9 @@ def __init__(self, module, device, config) -> None:
707
707
_setattr_from_module (self , module )
708
708
self .config = config
709
709
self .module_device = device
710
- self .num_groups = self .num_heads // self .num_key_value_heads
710
+ self .num_key_value_heads = config .num_key_value_heads
711
+ self .num_attention_heads = config .num_attention_heads
712
+ self .num_groups = self .num_attention_heads // self .num_key_value_heads
711
713
self .kv_head_mapping = torch .arange (
712
714
0 , self .num_key_value_heads , dtype = torch .int32 , device = self .module_device
713
715
).repeat_interleave (self .num_groups )
@@ -892,11 +894,11 @@ def __init__(self, module, device, config) -> None:
892
894
def qkv_gemm (self , hidden_states ):
893
895
if hasattr (self , "concat_qkv" ):
894
896
qkv_out = self .concat_qkv (hidden_states )
895
- query = qkv_out [:, : self .q_slice ].view (- 1 , self .num_heads , self .head_dim )
897
+ query = qkv_out [:, : self .q_slice ].view (- 1 , self .num_attention_heads , self .head_dim )
896
898
key = qkv_out [:, self .q_slice : self .k_slice ].view (- 1 , self .num_key_value_heads , self .head_dim )
897
899
value = qkv_out [:, self .k_slice :].view (- 1 , self .num_key_value_heads , self .head_dim )
898
900
else :
899
- query = self .q_proj (hidden_states ).view (- 1 , self .num_heads , self .head_dim )
901
+ query = self .q_proj (hidden_states ).view (- 1 , self .num_attention_heads , self .head_dim )
900
902
key = self .k_proj (hidden_states ).view (- 1 , self .num_key_value_heads , self .head_dim )
901
903
value = self .v_proj (hidden_states ).view (- 1 , self .num_key_value_heads , self .head_dim )
902
904
@@ -914,20 +916,21 @@ def __init__(self, module, device, config):
914
916
def qkv_gemm(self, hidden_states):
    """Project hidden states into per-head query/key/value tensors.

    Runs the fused ``query_key_value`` projection, then splits the result
    according to the model layout: with the new decoder architecture each
    KV group carries its query heads plus one key slot and one value slot
    (the last two positions of the group); otherwise contiguous slice
    boundaries (``q_slice`` / ``k_slice``) separate query, key and value.
    """
    fused = self.query_key_value(hidden_states)
    if self.new_decoder_architecture:
        # Each group holds (num_attention_heads // num_kv_heads) query
        # heads followed by one key head and one value head.
        grouped = fused.view(
            fused.shape[0], -1, self.num_attention_heads // self.num_kv_heads + 2, self.head_dim
        )
        query = grouped[:, :, :-2, :].flatten(1, 2)
        key = grouped[:, :, [-2], :].flatten(1, 2)
        value = grouped[:, :, [-1], :].flatten(1, 2)
        return query, key, value
    # Legacy layout: [query | key | value] laid out contiguously along the
    # last dimension, split by precomputed slice offsets.
    query = fused[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
    key = fused[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
    value = fused[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
    return query, key, value
926
930
927
931
928
932
class _IPEXGPT2Attention (_IPEXAttention ):
929
933
def __init__ (self , module , device , config ) -> None :
930
- self .num_key_value_heads = config .num_key_value_heads
931
934
super ().__init__ (module , device , config )
932
935
_setattr_from_module (self , module )
933
936
if not config .compile and getattr (config , "quantization_config" , None ) is None :
@@ -950,9 +953,9 @@ def qkv_gemm(self, hidden_states):
950
953
query , key , value = self .c_attn_linear (hidden_states ).split (self .split_size , dim = - 1 )
951
954
else :
952
955
query , key , value = self .c_attn (hidden_states ).split (self .split_size , dim = - 1 )
953
- query = query .view (- 1 , self .num_heads , self .head_dim )
954
- key = key .view (- 1 , self .num_heads , self .head_dim )
955
- value = value .view (- 1 , self .num_heads , self .head_dim )
956
+ query = query .view (- 1 , self .num_attention_heads , self .head_dim )
957
+ key = key .view (- 1 , self .num_attention_heads , self .head_dim )
958
+ value = value .view (- 1 , self .num_attention_heads , self .head_dim )
956
959
return query , key , value
957
960
958
961
def rope (self , query , key , * args , ** kwargs ):
0 commit comments