@@ -589,11 +589,10 @@ def forward_qdq(self, input):
589
589
return output
590
590
591
591
def forward_quant (self , input ):
592
- assert (
593
- input .shape [- 1 ] % self .world_size == 0
594
- ), "Please ensure that self.world_size is divisible by input.shape[-1]"
595
- input_shard = input .shape [- 1 ] // self .world_size
596
- splittedInput = input [:, :, self .rank * input_shard : (self .rank + 1 ) * input_shard ]
592
+ from deepspeed .module_inject .tp_shard import get_shard_size , get_shard_size_list
593
+ input_shard_size = get_shard_size (input .shape [- 1 ], self .world_size , "lm_head" )
594
+ input_shard_offset = sum (get_shard_size_list (input .shape [- 1 ], self .world_size , "lm_head" )[0 :self .rank ])
595
+ splittedInput = input [:, :, input_shard_offset :input_shard_offset + input_shard_size ]
597
596
qinput = self .quant_input (splittedInput )
598
597
output = self .matmul_fp8 (qinput ,
599
598
self .weight ,
@@ -611,11 +610,10 @@ def forward_quant(self, input):
611
610
return dqoutput
612
611
613
612
def forward_measure (self , input ):
614
- assert (
615
- input .shape [- 1 ] % self .world_size == 0
616
- ), "Please ensure that self.world_size is divisible by input.shape[-1]"
617
- input_shard = input .shape [- 1 ] // self .world_size
618
- splittedInput = input [:, :, self .rank * input_shard : (self .rank + 1 ) * input_shard ]
613
+ from deepspeed .module_inject .tp_shard import get_shard_size , get_shard_size_list
614
+ input_shard_size = get_shard_size (input .shape [- 1 ], self .world_size , "lm_head" )
615
+ input_shard_offset = sum (get_shard_size_list (input .shape [- 1 ], self .world_size , "lm_head" )[0 :self .rank ])
616
+ splittedInput = input [:, :, input_shard_offset :input_shard_offset + input_shard_size ]
619
617
measure_input ((splittedInput ,), observer = self ._mod_extra_config .inputs )
620
618
output = torch .matmul (splittedInput , self .weight .t ())
621
619
measure_output ((output ,), self ._mod_extra_config .outputs )
0 commit comments