@@ -93,6 +93,31 @@ def __init__(
        self.eos_token_id = eos_token_id


+        # Measure FLOPS - for forward pass only
+        self.flops_per_batch = lightning.fabric.utilities.throughput.measure_flops(self.model, self._sample_forward)
+
+    def _sample_forward(self):
+        batch_size = 64  # FIXME
+        batch = {
+            "input_ids": torch.randint(0, self.vocab_size, (batch_size, self.max_length)),
+            "attention_mask": torch.ones(batch_size, self.max_length),
+        }
+
+        tokens = batch["input_ids"]
+        B, length = tokens.shape
+        tokens = tokens.view(-1)
+        attention_mask = batch["attention_mask"].view(-1)
+
+        cu_seqlens = torch.tensor([0] + [(i + 1) * length for i in range(B)], dtype=torch.int32).cuda()
+
+        return self.model(
+            tokens,
+            attention_mask=attention_mask,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=self.max_length
+        )
+
+
    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        ppl = torch.exp(loss)
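`flops_per_batch` is the attribute that Lightning's `ThroughputMonitor` callback looks up on the `LightningModule` each iteration, so measuring it once at construction time lets the callback report FLOP/s alongside samples/s. A minimal sketch of wiring that callback into a `Trainer`, assuming batches keep the same dict layout used by `_sample_forward` (the `Trainer` arguments are placeholders):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ThroughputMonitor

# The callback needs the number of samples per batch; read it off the
# "input_ids" tensor of the batch dict.
throughput = ThroughputMonitor(batch_size_fn=lambda batch: batch["input_ids"].shape[0])

trainer = Trainer(callbacks=[throughput], log_every_n_steps=10)
# trainer.fit(model, datamodule)  # flops_per_batch is picked up automatically
```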
@@ -200,33 +225,5 @@ def _mask_inputs(self, train_inputs: torch.Tensor):

        return masked_inputs

-    def setup(self, stage: str | None):
-        """Used to measure FLOPs"""
-
-        with torch.device("meta"):
-            model = FlexBERT(self.config)
-
-        def sample_forward():
-            batch_size = 64  # TODO figure out how to avoid setting this manually
-            batch = {
-                "input_ids": torch.randint(0, model.vocab_size, (batch_size, model.max_length)),
-                "attention_mask": torch.ones(batch_size, model.max_length),
-            }

-            tokens = batch["input_ids"]
-            B, length = tokens.shape
-            tokens = tokens.view(-1)
-            attention_mask = batch["attention_mask"].view(-1)
-
-            cu_seqlens = torch.tensor([0] + [(i + 1) * length for i in range(B)], dtype=torch.int32).cuda()
-
-            return model.model(
-                tokens,
-                attention_mask=attention_mask,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=model.max_length
-            )
-
-        # Measure FLOPS - for forward pass only
-        self.flops_per_batch = lightning.fabric.utilities.throughput.measure_flops(model, sample_forward)

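For context on the `measure_flops` call that now lives in `__init__`: called without a `loss_fn` it counts forward-pass FLOPs only (matching the comment above), while supplying a `loss_fn` also traces the backward pass. A rough sketch with a throwaway toy module on the meta device, which is how the removed `setup()` avoided allocating real weights; none of these names come from this repo:

```python
import torch
import torch.nn as nn
from lightning.fabric.utilities.throughput import measure_flops

# Hypothetical stand-in model; the meta device avoids materializing weights.
with torch.device("meta"):
    toy = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
    x = torch.randn(64, 512)

fwd_only = measure_flops(toy, lambda: toy(x))                            # forward pass only
fwd_bwd = measure_flops(toy, lambda: toy(x), loss_fn=lambda y: y.sum())  # forward + backward
print(fwd_only, fwd_bwd)
```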