
Commit 48465df

ruff

1 parent a339ebf

8 files changed: +80 -28 lines

.DS_Store (8 KB)
Binary file not shown.

src/lobster/cmdline/_utils.py (+6 -2)
@@ -13,11 +13,15 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]:
         return callbacks
 
     if not isinstance(callbacks_cfg, DictConfig):
-        raise TypeError("[instantiate_callbacks] Callbacks config must be a DictConfig!")
+        raise TypeError(
+            "[instantiate_callbacks] Callbacks config must be a DictConfig!"
+        )
 
     for _, cb_conf in callbacks_cfg.items():
         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
-            print(f"[instantiate_callbacks] Instantiating callback <{cb_conf._target_}>")
+            print(
+                f"[instantiate_callbacks] Instantiating callback <{cb_conf._target_}>"
+            )
             callbacks.append(hydra.utils.instantiate(cb_conf))
 
     return callbacks
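
For context, the reflowed `instantiate_callbacks` logic above boils down to iterating a Hydra `DictConfig` and instantiating every entry that carries a `_target_`. A minimal usage sketch of that pattern; the `checkpoint` key and the `ModelCheckpoint` target below are placeholders, not taken from the repo:

```python
# Hypothetical illustration of the instantiate_callbacks pattern; the config
# key and callback class are placeholders.
import hydra
from omegaconf import DictConfig, OmegaConf

callbacks_cfg = OmegaConf.create(
    {
        "checkpoint": {
            "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
            "save_top_k": 1,
        }
    }
)

callbacks = []
for _, cb_conf in callbacks_cfg.items():
    if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
        # hydra.utils.instantiate builds the object named by _target_
        callbacks.append(hydra.utils.instantiate(cb_conf))

print(callbacks)
```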

src/lobster/data/_collate.py (+9 -3)
@@ -99,7 +99,9 @@ def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
         if self._contact_maps:
             # Flatten the output of Atom3D transforms
             flattened_batch = [
-                (a, b, c) for ((a, b), c) in raw_batch if (a is not None) and (b is not None)
+                (a, b, c)
+                for ((a, b), c) in raw_batch
+                if (a is not None) and (b is not None)
             ]
             batch_size = len(flattened_batch)
             if batch_size == 0:
@@ -114,8 +116,12 @@ def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
 
         if self.truncation_seq_length:
             # NOTE - This removes eos token for long sequences. Should we re-add eos or keep as is?
-            seq1_tokenized = [seq[: self.truncation_seq_length] for seq in seq1_tokenized]
-            seq2_tokenized = [seq[: self.truncation_seq_length] for seq in seq2_tokenized]
+            seq1_tokenized = [
+                seq[: self.truncation_seq_length] for seq in seq1_tokenized
+            ]
+            seq2_tokenized = [
+                seq[: self.truncation_seq_length] for seq in seq2_tokenized
+            ]
 
         tokens1 = pad_sequence(
             seq1_tokenized, batch_first=True, padding_value=self.tokenizer.pad_token_id
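
The truncate-then-pad step reformatted above follows a common pattern: clip each tokenized sequence to `truncation_seq_length`, then right-pad the batch to a common length. A minimal standalone sketch with made-up tensors and pad id (not the repo's collator):

```python
# Standalone sketch of truncate-then-pad; lengths, ids, and the pad value are
# assumptions for illustration only.
import torch
from torch.nn.utils.rnn import pad_sequence

truncation_seq_length = 4
pad_token_id = 0

seq1_tokenized = [torch.tensor([5, 6, 7, 8, 9, 2]), torch.tensor([5, 6, 2])]

# Clip long sequences; as the NOTE in the diff says, this can drop a trailing
# eos token.
seq1_tokenized = [seq[:truncation_seq_length] for seq in seq1_tokenized]

# Pad to the longest remaining sequence so the batch stacks into one tensor.
tokens1 = pad_sequence(seq1_tokenized, batch_first=True, padding_value=pad_token_id)
print(tokens1.shape)  # torch.Size([2, 4])
```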

src/lobster/data/_farthest_first_traversal.py (+9 -2)
@@ -3,7 +3,11 @@
 
 class FarthestFirstTraversal:
     def __init__(
-        self, num_samples: int, k: int = 10, random_seed: int = 0xDEADBEEF, p_norm: int = 2
+        self,
+        num_samples: int,
+        k: int = 10,
+        random_seed: int = 0xDEADBEEF,
+        p_norm: int = 2,
     ):
         """
         Parameters
@@ -45,7 +49,10 @@ def str_fft(self, inputs: list[str]):
         inputs = [inputs[i] for i in perm]
         centroids = [inputs[i] for i in range(self._k)]
         while len(centroids) < self._num_samples:
-            dist = [min(self._levenshtein(str1, str2) for str2 in centroids) for str1 in inputs]
+            dist = [
+                min(self._levenshtein(str1, str2) for str2 in centroids)
+                for str1 in inputs
+            ]
             farthest = dist.index(max(dist))
             if inputs[farthest] in centroids:
                 break
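
The loop reformatted above implements farthest-first traversal over strings: each iteration adds the input whose distance to its nearest centroid is largest. A self-contained sketch of that idea; the `levenshtein` helper and the standalone `str_fft` function below are illustrative assumptions, not the repo's implementation (which also handles seeding, tensors, and a p-norm):

```python
# Minimal farthest-first traversal over strings, using edit distance.
import random


def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance (assumed helper)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = curr
    return prev[-1]


def str_fft(inputs: list[str], num_samples: int, k: int = 1, seed: int = 0) -> list[str]:
    rng = random.Random(seed)
    inputs = rng.sample(inputs, len(inputs))  # random permutation
    centroids = inputs[:k]                    # seed centroids
    while len(centroids) < num_samples:
        # Distance of every point to its nearest centroid ...
        dist = [min(levenshtein(s, c) for c in centroids) for s in inputs]
        # ... then pick the point farthest from all current centroids.
        farthest = dist.index(max(dist))
        if inputs[farthest] in centroids:
            break
        centroids.append(inputs[farthest])
    return centroids


print(str_fft(["AAAA", "AAAT", "GGGG", "GGGC", "TTTT"], num_samples=3))
```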

src/lobster/metrics/_binary_classification.py (+3 -1)
@@ -22,7 +22,9 @@ def summarize_binary_classification_metrics(preds, labels):
     """
     # Initialize metric objects
    accuracy = Accuracy(task="binary")
-    precision = Precision(task="binary", num_classes=2, average="micro")  # binary classification
+    precision = Precision(
+        task="binary", num_classes=2, average="micro"
+    )  # binary classification
     recall = Recall(task="binary", num_classes=2, average="micro")
     f1_score = F1Score(task="binary", num_classes=2, average="micro")
     auroc = AUROC(task="binary", num_classes=1)  # binary classification
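
For reference, the metric objects built above come from torchmetrics and can be called directly on predictions and labels. A brief usage sketch with toy tensors (assumed, not from the repo); the constructor arguments mirror the ones in the hunk:

```python
# Toy usage of the torchmetrics objects shown in the diff above.
import torch
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall

preds = torch.tensor([0.9, 0.2, 0.7, 0.4])  # predicted probabilities
labels = torch.tensor([1, 0, 1, 1])         # ground-truth classes

metrics = {
    "accuracy": Accuracy(task="binary"),
    "precision": Precision(task="binary", num_classes=2, average="micro"),
    "recall": Recall(task="binary", num_classes=2, average="micro"),
    "f1": F1Score(task="binary", num_classes=2, average="micro"),
    "auroc": AUROC(task="binary", num_classes=1),
}

summary = {name: metric(preds, labels).item() for name, metric in metrics.items()}
print(summary)
```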

src/lobster/model/_seq2seq_configuration.py (+1 -1)
@@ -134,7 +134,7 @@ def __init__(
         key_bias=True,
         value_bias=True,
         intermediate_bias=True,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
 
src/lobster/model/hyena/_hyena_base.py (+50 -17)
@@ -139,10 +139,14 @@ def __init__(self, config, **kwargs):
             act,
         )
         for _i in range(config.num_inner_mlps):
-            self.implicit_filter.append(nn.Linear(config.filter_order, config.filter_order))
+            self.implicit_filter.append(
+                nn.Linear(config.filter_order, config.filter_order)
+            )
             self.implicit_filter.append(act)
 
-        self.implicit_filter.append(nn.Linear(config.filter_order, config.d_model, bias=False))
+        self.implicit_filter.append(
+            nn.Linear(config.filter_order, config.d_model, bias=False)
+        )
 
         self.modulation = HyenaExponentialModulation(config.d_model)
 
@@ -191,7 +195,11 @@ def __init__(
         self.out_proj = nn.Linear(self.d_model, self.d_model)
 
         self.short_filter = nn.Conv1d(
-            inner_width, inner_width, config.short_filter_order, padding=2, groups=inner_width
+            inner_width,
+            inner_width,
+            config.short_filter_order,
+            padding=2,
+            groups=inner_width,
         )
         self.filter_fn = HyenaFilter(config)
 
@@ -297,7 +305,9 @@ def __init__(self, config, padding_idx=None):
             vocab_size += config.pad_vocab_size_multiple - (
                 vocab_size % config.pad_vocab_size_multiple
             )
-        self.word_embeddings = nn.Embedding(vocab_size, config.d_model, padding_idx=padding_idx)
+        self.word_embeddings = nn.Embedding(
+            vocab_size, config.d_model, padding_idx=padding_idx
+        )
 
     def forward(self, input_ids):
         """
@@ -330,7 +340,9 @@ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
 
         for layer in self.layers:
             if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
+                hidden_states = self._gradient_checkpointing_func(
+                    layer.__call__, hidden_states
+                )
             else:
                 hidden_states = layer(hidden_states)
             if output_hidden_states:
@@ -349,7 +361,9 @@ class HyenaDNAPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["HyenaBlock"]
     _skip_keys_device_placement = "past_key_values"
-    _keys_to_ignore_on_load_missing = [r"freq"]  # Shared tensors that safetensors merges
+    _keys_to_ignore_on_load_missing = [
+        r"freq"
+    ]  # Shared tensors that safetensors merges
 
     def _init_weights(self, module, initializer_range=0.02):
         if isinstance(module, nn.Linear):
@@ -368,13 +382,17 @@ def _init_weights(self, module, initializer_range=0.02):
             if name in ["out_proj.weight", "fc2.weight"]:
                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                 nn.init.normal_(
-                    p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers)
+                    p,
+                    mean=0.0,
+                    std=initializer_range / math.sqrt(2 * self.config.num_layers),
                 )
             # If using GLU activation for now, we scale the std by 2
             elif name in ["output_linear.0.weight"]:
                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                 nn.init.normal_(
-                    p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers)
+                    p,
+                    mean=0.0,
+                    std=initializer_range / math.sqrt(2 * self.config.num_layers),
                 )
 
 
@@ -388,16 +406,22 @@ def __init__(self, config, **kwargs) -> None:
         # Initialize weights and apply final processing
         self.post_init()
 
-    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None):
+    def forward(
+        self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None
+    ):
         output_hidden_states = (
             output_hidden_states
             if output_hidden_states is not None
             else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         hidden_states, all_hidden_states = self.backbone(
-            input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states
+            input_ids,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
         )
         if return_dict:
             return BaseModelOutputWithNoAttention(
@@ -451,13 +475,14 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutput]:
-
         output_hidden_states = (
             output_hidden_states
             if output_hidden_states is not None
             else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.hyena(
@@ -525,7 +550,9 @@ def forward(
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         transformer_outputs = self.hyena(
             input_ids,
@@ -542,7 +569,9 @@ def forward(
             batch_size = inputs_embeds.shape[0]
 
         if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+            raise ValueError(
+                "Cannot handle batch sizes > 1 if no padding token is defined."
+            )
         if self.config.pad_token_id is None:
             sequence_lengths = -1
         else:
@@ -553,7 +582,9 @@ def forward(
             else:
                 sequence_lengths = -1
 
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[
+            torch.arange(batch_size, device=logits.device), sequence_lengths
+        ]
 
         loss = None
         if labels is not None:
@@ -576,7 +607,9 @@ def forward(
                 loss = loss_fct(pooled_logits, labels)
             elif self.config.problem_type == "single_label_classification":
                 loss_fct = nn.CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+                loss = loss_fct(
+                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
+                )
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = nn.BCEWithLogitsLoss()
                 loss = loss_fct(pooled_logits, labels)
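
The `pooled_logits` indexing reformatted above is the usual sequence-classification trick: for each row, select the logits at the last non-pad position. A standalone sketch with toy shapes; the way `sequence_lengths` is computed here is an assumption (the diff only shows the `-1` fallback branches), modeled on the common pad-mask argmax idiom:

```python
# Toy illustration of last-non-pad pooling; shapes, ids, and the
# sequence_lengths computation are assumptions for illustration.
import torch

batch_size, seq_len, num_labels = 2, 5, 3
pad_token_id = 0

input_ids = torch.tensor(
    [
        [7, 8, 9, 0, 0],  # padded: real length 3
        [4, 5, 6, 7, 8],  # no padding: use the final position
    ]
)
logits = torch.randn(batch_size, seq_len, num_labels)

# First pad position minus one gives the last real token; rows without padding
# fall back to index -1 (the final position).
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1

pooled_logits = logits[
    torch.arange(batch_size, device=logits.device), sequence_lengths
]
print(pooled_logits.shape)  # torch.Size([2, 3])
```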

src/lobster/tokenization/_cached_bert_tokenizer.py (+2 -2)
@@ -22,7 +22,7 @@ def __init__(
         pad_token: str = "[PAD]",
         cls_token: str = "[CLS]",
         mask_token: str = "[MASK]",
-        **kwargs
+        **kwargs,
     ):
         super().__init__(
             vocab_file=vocab_file,
@@ -33,7 +33,7 @@
             pad_token=pad_token,
             cls_token=cls_token,
             mask_token=mask_token,
-            **kwargs
+            **kwargs,
         )
         self.padding_idx = self.vocab[pad_token]
         self.masking_idx = self.vocab[mask_token]
