Commit 9ceb958

[ALGO-809] PatchedLmHeadLinearAllreduce: replacing the sharding code with the one from deepspeed-fork (#85)
Change-Id: Icb9670cfefdd1880c1ebb9a804a97c9ba79ecdc3
Co-authored-by: smarkovichgolan <smarkovich@habana.ai>
1 parent: 795b716 · commit: 9ceb958

File tree

1 file changed: +8 −10 lines changed

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py (+8 −10)
```diff
@@ -589,11 +589,10 @@ def forward_qdq(self, input):
         return output

     def forward_quant(self, input):
-        assert (
-            input.shape[-1] % self.world_size == 0
-        ), "Please ensure that self.world_size is divisible by input.shape[-1]"
-        input_shard = input.shape[-1] // self.world_size
-        splittedInput = input[:, :, self.rank * input_shard : (self.rank + 1) * input_shard]
+        from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
+        input_shard_size = get_shard_size(input.shape[-1], self.world_size, "lm_head")
+        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size, "lm_head")[0:self.rank])
+        splittedInput = input[:, :, input_shard_offset:input_shard_offset + input_shard_size]
         qinput = self.quant_input(splittedInput)
         output = self.matmul_fp8(qinput,
                                  self.weight,
```
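The key effect of this hunk: the slice boundaries no longer assume that `input.shape[-1]` divides evenly by `self.world_size`. Each rank asks DeepSpeed's shard helpers for its shard size, and derives its offset by summing the sizes of the preceding ranks. A minimal sketch of that arithmetic, using a hypothetical even-as-possible splitter (`shard_size_list`/`shard_bounds` are illustrative names, not the DeepSpeed API; the real `get_shard_size`/`get_shard_size_list` in `deepspeed.module_inject.tp_shard` may also consult the tensor-parallel config):

```python
# Hypothetical stand-in for deepspeed.module_inject.tp_shard: split
# total_size into world_size chunks as evenly as possible, giving the
# first (total_size % world_size) ranks one extra element.
def shard_size_list(total_size: int, world_size: int) -> list[int]:
    base, rem = divmod(total_size, world_size)
    return [base + (1 if r < rem else 0) for r in range(world_size)]

def shard_bounds(total_size: int, world_size: int, rank: int) -> tuple[int, int]:
    sizes = shard_size_list(total_size, world_size)
    offset = sum(sizes[:rank])  # same offset arithmetic as the patched code
    return offset, sizes[rank]

# Hidden size 10 over 3 ranks: the removed assert would have rejected this.
for rank in range(3):
    print(rank, shard_bounds(10, 3, rank))
# -> (0, 4), (4, 3), (7, 3): shards [0:4], [4:7], [7:10] cover the input
```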
```diff
@@ -611,11 +610,10 @@ def forward_quant(self, input):
         return dqoutput

     def forward_measure(self, input):
-        assert (
-            input.shape[-1] % self.world_size == 0
-        ), "Please ensure that self.world_size is divisible by input.shape[-1]"
-        input_shard = input.shape[-1] // self.world_size
-        splittedInput = input[:, :, self.rank * input_shard : (self.rank + 1) * input_shard]
+        from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
+        input_shard_size = get_shard_size(input.shape[-1], self.world_size, "lm_head")
+        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size, "lm_head")[0:self.rank])
+        splittedInput = input[:, :, input_shard_offset:input_shard_offset + input_shard_size]
         measure_input((splittedInput,), observer=self._mod_extra_config.inputs)
         output = torch.matmul(splittedInput, self.weight.t())
         measure_output((output,), self._mod_extra_config.outputs)
```
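For context, a short usage sketch of the slicing pattern both hunks introduce, under the same assumptions as above (the `shard_bounds` stand-in and a dummy activation tensor; in the patched module, `self.rank` and `self.world_size` come from the distributed environment):

```python
import torch

hidden, world_size = 10, 3
x = torch.randn(2, 5, hidden)  # (batch, seq, hidden) dummy activation

shards = []
for rank in range(world_size):
    offset, size = shard_bounds(hidden, world_size, rank)
    # mirrors: input[:, :, input_shard_offset:input_shard_offset + input_shard_size]
    shards.append(x[:, :, offset:offset + size])

# The per-rank slices partition the hidden dimension exactly, even though
# 10 is not divisible by 3.
assert torch.equal(torch.cat(shards, dim=-1), x)
```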
