     "output_attentions": False,
     "output_hidden_states": False,
     "tie_word_embeddings": True,
+    "return_dict": True,
 }
 
 
 DUMMY_MODELS_TO_TEST = (
@@ -64,11 +65,10 @@ def prepare_dummy_inputs(
     seq_len: int = 10,
     device: Union[str, torch.device] = "cuda",
 ):
-    return {
-        "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device),
-        "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64, device=device),
-        "position_ids": torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1),
-    }
+    input_ids = torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device)
+    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
+    labels = input_ids.clone()
+    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
 
 
 def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]):
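Why these input changes make `.loss` available: with `"return_dict": True` in the dummy config, Hugging Face models return a `ModelOutput` instead of a tuple, and supplying `labels` makes causal-LM heads compute a scalar cross-entropy (the labels are shifted internally), which is what lets the hunks below read `model(**inputs).loss` instead of indexing logits. A minimal sketch of that mechanism, outside this test suite; the tiny checkpoint is only an illustrative stand-in for the dummy models under test:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
input_ids = torch.randint(1, model.config.vocab_size, (2, 10))
attention_mask = torch.ones_like(input_ids)

# With labels present, the output carries a scalar loss; without them, .loss is None
# and callers fall back to outputs.logits (or tuple indexing when return_dict=False).
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.clone())
print(outputs.loss.shape)    # torch.Size([]) — a scalar
print(outputs.logits.shape)  # torch.Size([2, 10, vocab_size])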
@@ -82,8 +82,8 @@ def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]):
 
     model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
     inputs = prepare_dummy_inputs(model.config)
-    logits = model(**inputs)[0]
-    tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size)
+    loss = model(**inputs).loss
+    tensors = gather_at_main_process(tensor=loss, group=tp_group, rank=rank, world_size=world_size)
 
     # check results at main worker process
     if rank == 0:
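The assertion body under `if rank == 0:` sits outside this diff. Under tensor parallelism every rank should now produce the same scalar loss, so the gathered values would be compared pairwise, roughly along these lines (a hypothetical reconstruction assuming `gather_at_main_process` returns a list of per-rank tensors on rank 0, not the file's actual code):

# Hypothetical sketch of the rank-0 check consuming the gathered losses.
if rank == 0:
    for i in range(1, world_size):
        torch.testing.assert_close(tensors[i], tensors[0], rtol=1e-4, atol=1e-4)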
@@ -145,7 +145,7 @@ def run_test_parallel_results_matches_non_parallel(
     inputs = prepare_dummy_inputs(model.config)
 
     set_seed(SEED)
-    logits = model(**inputs)[0]
+    loss = model(**inputs).loss
 
     torch._dynamo.reset()
     del model
@@ -154,9 +154,9 @@ def run_test_parallel_results_matches_non_parallel(
     set_seed(SEED)
     ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
     model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
-    parallel_logits = model(**inputs)[0]
+    parallel_loss = model(**inputs).loss
 
-    torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4)
+    torch.testing.assert_close(loss.cpu(), parallel_loss.cpu(), rtol=1e-4, atol=1e-4)
 
     dist.barrier(tp_group)
     tearDown()