@@ -114,6 +114,7 @@ def __init__(
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.routed_scaling_factor = config.routed_scaling_factor
+        self._prefix = prefix
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -164,6 +165,7 @@ def __init__(


     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # show_mem_info(logger, f"{self._prefix}: before gate")
         batch_size, seq_len, hidden_dim = hidden_states.shape
         num_tokens = batch_size * seq_len
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -172,15 +174,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         hidden_states = hidden_states.reshape(batch_size, seq_len, hidden_dim)
+        # show_mem_info(logger, f"{self._prefix}: shared_output shape {shared_output.shape}, router_logits shape {router_logits.shape}, hidden_states shape {hidden_states.shape}")
+        # show_mem_info(logger, f"{self._prefix}: before experts")
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits) * self.routed_scaling_factor
+        # show_mem_info(logger, f"{self._prefix}: after experts")
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.ep_size == 1 and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
-
+        # show_mem_info(logger, f"{self._prefix}: before return")
         return final_hidden_states.view(batch_size, seq_len, hidden_dim)


@@ -536,6 +541,7 @@ def __init__(
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
+        self._prefix = prefix
         if model_config.use_mla:
             attn_cls = DeepseekV3MLAAttention
         else:
@@ -594,20 +600,20 @@ def forward(
         hidden_states, residual = self.input_layernorm(
             hidden_states, residual)
         # logger.info(f"hidden_states shape : {hidden_states.shape}")
-        # show_mem_info(logger, "DeepseekV3DecoderLayer : before self_attn")
+        # show_mem_info(logger, f"{self._prefix} : before self_attn")
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
         )
-        # show_mem_info(logger, "DeepseekV3DecoderLayer: after self_attn")
-        htorch.core.mark_step()
+        # htorch.core.mark_step()
+        # show_mem_info(logger, f"{self._prefix}: after self_attn")
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        # show_mem_info(logger, "DeepseekV3DecoderLayer : after mlp")
+        # show_mem_info(logger, f"{self._prefix} : after mlp")
         return hidden_states, residual


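For reference, the commented-out debug calls above assume a `show_mem_info(logger, tag)` helper that is not part of this diff. Below is a minimal sketch of what such a helper could look like; it is a guess rather than the repository's actual implementation, and it assumes the Gaudi PyTorch bridge (`habana_frameworks.torch`, referenced as `htorch` elsewhere in this file) exposes CUDA-style memory counters under `htorch.hpu`.

# Hypothetical sketch of the show_mem_info helper referenced in the
# commented-out debug lines above; the real helper may differ.
import logging

import habana_frameworks.torch as htorch  # Gaudi bridge (assumed available)


def show_mem_info(logger: logging.Logger, tag: str) -> None:
    # Assumption: htorch.hpu provides memory_allocated()/max_memory_allocated()
    # analogous to torch.cuda; adjust the calls if the installed release
    # names these counters differently.
    allocated_gib = htorch.hpu.memory_allocated() / 2**30
    peak_gib = htorch.hpu.max_memory_allocated() / 2**30
    logger.info("%s | HPU allocated=%.2f GiB, peak=%.2f GiB",
                tag, allocated_gib, peak_gib)

With the `self._prefix` attribute added in this commit, each uncommented call (for example `show_mem_info(logger, f"{self._prefix}: before experts")`) tags its log line with the layer it came from, which makes per-layer memory growth easier to trace.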