Skip to content

Commit f2a53c0

Browse files
committed
flops
1 parent 58bcd8d commit f2a53c0

File tree

1 file changed

+25
-28
lines changed

1 file changed

+25
-28
lines changed

src/lobster/model/modern_bert/_modern_bert.py

+25-28
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,31 @@ def __init__(
9393
self.eos_token_id = eos_token_id
9494

9595

96+
# Measure FLOPS - for forward pass only
97+
self.flops_per_batch = lightning.fabric.utilities.throughput.measure_flops(self.model, self._sample_forward)
98+
99+
def _sample_forward(self):
100+
batch_size = 64 # FIXME
101+
batch = {
102+
"input_ids": torch.randint(0, self.vocab_size, (batch_size, self.max_length)),
103+
"attention_mask": torch.ones(batch_size, self.max_length),
104+
}
105+
106+
tokens = batch["input_ids"]
107+
B, length = tokens.shape
108+
tokens = tokens.view(-1)
109+
attention_mask = batch["attention_mask"].view(-1)
110+
111+
cu_seqlens = torch.tensor([0] + [(i + 1) * length for i in range(B)], dtype=torch.int32).cuda()
112+
113+
return self.model(
114+
tokens,
115+
attention_mask=attention_mask,
116+
cu_seqlens=cu_seqlens,
117+
max_seqlen=self.max_length
118+
)
119+
120+
96121
def training_step(self, batch, batch_idx):
97122
loss = self._compute_loss(batch)
98123
ppl = torch.exp(loss)
@@ -200,33 +225,5 @@ def _mask_inputs(self, train_inputs: torch.Tensor):
200225

201226
return masked_inputs
202227

203-
def setup(self, stage: str | None):
    """Measure forward-pass FLOPs once at setup time.

    Instantiates a throwaway copy of the model on the ``meta`` device
    (no real memory allocated), runs a synthetic forward pass through
    Lightning's ``measure_flops``, and stores the result on
    ``self.flops_per_batch``.
    """
    # Meta-device construction: parameters get shapes but no storage.
    with torch.device("meta"):
        model = FlexBERT(self.config)

    def sample_forward():
        # TODO figure out how to avoid setting this manually
        n_seqs = 64
        seq_len = model.max_length

        input_ids = torch.randint(0, model.vocab_size, (n_seqs, seq_len))
        mask = torch.ones(n_seqs, seq_len)

        # Concatenate all sequences into one flat token stream.
        flat_tokens = input_ids.view(-1)
        flat_mask = mask.view(-1)

        # Per-sequence start offsets: [0, L, 2L, ..., n_seqs*L].
        offsets = torch.tensor(
            [0] + [(i + 1) * seq_len for i in range(n_seqs)], dtype=torch.int32
        ).cuda()

        return model.model(
            flat_tokens,
            attention_mask=flat_mask,
            cu_seqlens=offsets,
            max_seqlen=seq_len,
        )

    # Measure FLOPS - for forward pass only
    self.flops_per_batch = lightning.fabric.utilities.throughput.measure_flops(model, sample_forward)
232229

0 commit comments

Comments
 (0)