@@ -614,22 +614,6 @@ def forward(
         if past_len == 0:
             # prefill, remove padding
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            # varlen_attention(
-            #     query.contiguous() if query.device.type == "xpu" else query,
-            #     key.contiguous() if key.device.type == "xpu" else key,
-            #     value.contiguous() if value.device.type == "xpu" else value,
-            #     attn_output,
-            #     seq_len_tensor,
-            #     seq_len_tensor,
-            #     input_lens.max(),
-            #     input_lens.max(),
-            #     0.0,
-            #     1.0 / math.sqrt(self.head_dim),
-            #     False,
-            #     True,
-            #     False,
-            #     None,
-            # )
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
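For context, the seq_len_tensor built above packs the variable-length prompts into cumulative start offsets, which is the layout varlen/flash-attention kernels expect for padding-free batches. A minimal standalone sketch, assuming input_lens is a 1-D tensor of per-sequence prompt lengths (the lengths below are made up for illustration):

import torch

# Hypothetical per-sequence prompt lengths for a batch of 3 requests.
input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Prepend 0 and take the cumulative sum: sequence i occupies rows
# seq_len_tensor[i]:seq_len_tensor[i+1] of the packed (padding-free) q/k/v tensors.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
print(seq_len_tensor)  # -> tensor([0, 5, 8, 15], dtype=torch.int32)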
@@ -734,9 +718,16 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
+        _setattr_from_module(self, module)
+        self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
+        self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
+        self.c_attn_linear.bias = self.c_attn.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias

     def qkv_gemm(self, hidden_states):
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
+        query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
         query = query.view(-1, self.num_heads, self.head_dim)
         key = key.view(-1, self.num_heads, self.head_dim)
         value = value.view(-1, self.num_heads, self.head_dim)
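The c_attn_linear/c_proj_linear construction works because GPT-2's Conv1D stores its weight as (in_features, out_features), the transpose of nn.Linear's (out_features, in_features) layout, so copying the transposed weight (and the bias) yields an equivalent module. A standalone sketch with illustrative shapes, assuming transformers' Conv1D (exposed as transformers.pytorch_utils.Conv1D in recent versions):

import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

conv = Conv1D(nf=12, nx=4)                                      # weight shape (nx, nf) = (4, 12)
linear = nn.Linear(conv.weight.shape[0], conv.weight.shape[1])  # in_features=4, out_features=12
linear.weight = nn.Parameter(conv.weight.t())                   # transpose to (out, in) = (12, 4)
linear.bias = conv.bias

x = torch.randn(2, 4)
# Conv1D computes x @ weight + bias; Linear computes x @ weight.T + bias,
# so with the transposed weight the two modules produce the same output.
assert torch.allclose(conv(x), linear(x), atol=1e-6)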
@@ -748,7 +739,6 @@ def rope(self, query, key, *args, **kwargs):
     def postprocess_attention_output(self, attn_output):
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
         return attn_output
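Dropping the resid_dropout call is presumably an inference-only simplification: in eval mode nn.Dropout is already the identity, so the output is unchanged. A quick standalone check of that property (not tied to this class):

import torch
from torch import nn

drop = nn.Dropout(p=0.1)
drop.eval()                     # inference mode: dropout becomes a no-op
x = torch.randn(4, 8)
assert torch.equal(drop(x), x)  # identity, so removing the call is safe at inference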