@@ -2467,3 +2467,134 @@ def patched_forward(*args, **kwargs):
             return outputs
 
         self.patched_forward = patched_forward
+
+
+def _decilm_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # DeciLM's attention implementation contains a bug when past key values are provided
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors.
+
+        Args:
+            q (`torch.Tensor`): The query tensor.
+            k (`torch.Tensor`): The key tensor.
+            cos (`torch.Tensor`): The cosine part of the rotary embedding.
+            sin (`torch.Tensor`): The sine part of the rotary embedding.
+            position_ids (`torch.Tensor`):
+                The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+                used to pass offsetted position ids when working with a KV-cache.
+            unsqueeze_dim (`int`, *optional*, defaults to 1):
+                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+        Returns:
+            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+        """
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+        """
+        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+        """
+        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+        if n_rep == 1:
+            return hidden_states
+        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+    bsz, q_len, _ = hidden_states.size()
+    if self.pretraining_tp > 1:
+        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
+        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
+        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
+        query_states = torch.cat(query_states, dim=-1)
+
+        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
+        key_states = torch.cat(key_states, dim=-1)
+
+        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
+        value_states = torch.cat(value_states, dim=-1)
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    attn_output = F.scaled_dot_product_attention(
+        query_states, key_states, value_states, is_causal=attention_mask is None, attn_mask=attention_mask
+    )
+
+    # modified: the original implementation is missing the .transpose(1, 2) here
+    attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
+
+    if self.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
+        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
+    else:
+        attn_output = self.o_proj(attn_output)
+
+    attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+class DeciLMModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        for layer in self._model.model.layers:
+            layer.self_attn._orig_forward = layer.self_attn.forward
+            layer.self_attn.forward = types.MethodType(_decilm_attn_forward, layer.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+
+        for layer in self._model.model.layers:
+            layer.self_attn.forward = layer.self_attn._orig_forward
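For reference, a small self-contained shape check (illustration only, not part of the patch) of why the added .transpose(1, 2) matters: F.scaled_dot_product_attention returns its output in the same (batch, heads, seq_len, head_dim) layout as the query, so the heads axis has to be moved next to head_dim before the .view(bsz, q_len, hidden_size) reshape, otherwise values from different heads get interleaved across sequence positions. The sizes below are arbitrary.

    import torch
    import torch.nn.functional as F

    bsz, n_heads, q_len, head_dim = 2, 4, 6, 8
    q = torch.randn(bsz, n_heads, q_len, head_dim)
    k = torch.randn(bsz, n_heads, q_len, head_dim)
    v = torch.randn(bsz, n_heads, q_len, head_dim)

    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # SDPA keeps the query layout: (batch, heads, seq_len, head_dim)
    assert out.shape == (bsz, n_heads, q_len, head_dim)

    # bring heads next to head_dim first, then merge them into hidden_size
    hidden = out.transpose(1, 2).contiguous().view(bsz, q_len, n_heads * head_dim)
    assert hidden.shape == (bsz, q_len, n_heads * head_dim)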