Skip to content

Commit d98e9ba

Browse files
fix
1 parent 79316ee commit d98e9ba

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

tests/fx/parallelization/test_tensor_parallel.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"output_attentions": False,
3737
"output_hidden_states": False,
3838
"tie_word_embeddings": True,
39+
"return_dict": True,
3940
}
4041

4142
DUMMY_MODELS_TO_TEST = (
@@ -64,11 +65,10 @@ def prepare_dummy_inputs(
6465
seq_len: int = 10,
6566
device: Union[str, torch.device] = "cuda",
6667
):
67-
return {
68-
"input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device),
69-
"attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64, device=device),
70-
"position_ids": torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1),
71-
}
68+
input_ids = torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device)
69+
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
70+
labels = input_ids.clone()
71+
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
7272

7373

7474
def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]):
@@ -82,8 +82,8 @@ def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, m
8282

8383
model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
8484
inputs = prepare_dummy_inputs(model.config)
85-
logits = model(**inputs)[0]
86-
tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size)
85+
loss = model(**inputs).loss
86+
tensors = gather_at_main_process(tensor=loss, group=tp_group, rank=rank, world_size=world_size)
8787

8888
# check results at main worker process
8989
if rank == 0:
@@ -145,7 +145,7 @@ def run_test_parallel_results_matches_non_parallel(
145145
inputs = prepare_dummy_inputs(model.config)
146146

147147
set_seed(SEED)
148-
logits = model(**inputs)[0]
148+
loss = model(**inputs).loss
149149

150150
torch._dynamo.reset()
151151
del model
@@ -154,9 +154,9 @@ def run_test_parallel_results_matches_non_parallel(
154154
set_seed(SEED)
155155
ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
156156
model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
157-
parallel_logits = model(**inputs)[0]
157+
parallel_loss = model(**inputs).loss
158158

159-
torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4)
159+
torch.testing.assert_close(loss.cpu(), parallel_loss.cpu(), rtol=1e-4, atol=1e-4)
160160

161161
dist.barrier(tp_group)
162162
tearDown()

0 commit comments

Comments
 (0)