@@ -139,10 +139,14 @@ def __init__(self, config, **kwargs):
             act,
         )
         for _i in range(config.num_inner_mlps):
-            self.implicit_filter.append(nn.Linear(config.filter_order, config.filter_order))
+            self.implicit_filter.append(
+                nn.Linear(config.filter_order, config.filter_order)
+            )
             self.implicit_filter.append(act)

-        self.implicit_filter.append(nn.Linear(config.filter_order, config.d_model, bias=False))
+        self.implicit_filter.append(
+            nn.Linear(config.filter_order, config.d_model, bias=False)
+        )

         self.modulation = HyenaExponentialModulation(config.d_model)

@@ -191,7 +195,11 @@ def __init__(
         self.out_proj = nn.Linear(self.d_model, self.d_model)

         self.short_filter = nn.Conv1d(
-            inner_width, inner_width, config.short_filter_order, padding=2, groups=inner_width
+            inner_width,
+            inner_width,
+            config.short_filter_order,
+            padding=2,
+            groups=inner_width,
         )
         self.filter_fn = HyenaFilter(config)

@@ -297,7 +305,9 @@ def __init__(self, config, padding_idx=None):
             vocab_size += config.pad_vocab_size_multiple - (
                 vocab_size % config.pad_vocab_size_multiple
             )
-        self.word_embeddings = nn.Embedding(vocab_size, config.d_model, padding_idx=padding_idx)
+        self.word_embeddings = nn.Embedding(
+            vocab_size, config.d_model, padding_idx=padding_idx
+        )

     def forward(self, input_ids):
         """
@@ -330,7 +340,9 @@ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):

         for layer in self.layers:
             if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
+                hidden_states = self._gradient_checkpointing_func(
+                    layer.__call__, hidden_states
+                )
             else:
                 hidden_states = layer(hidden_states)
             if output_hidden_states:
@@ -349,7 +361,9 @@ class HyenaDNAPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["HyenaBlock"]
     _skip_keys_device_placement = "past_key_values"
-    _keys_to_ignore_on_load_missing = [r"freq"]  # Shared tensors that safetensors merges
+    _keys_to_ignore_on_load_missing = [
+        r"freq"
+    ]  # Shared tensors that safetensors merges

     def _init_weights(self, module, initializer_range=0.02):
         if isinstance(module, nn.Linear):
@@ -368,13 +382,17 @@ def _init_weights(self, module, initializer_range=0.02):
             if name in ["out_proj.weight", "fc2.weight"]:
                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                 nn.init.normal_(
-                    p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers)
+                    p,
+                    mean=0.0,
+                    std=initializer_range / math.sqrt(2 * self.config.num_layers),
                 )
             # If using GLU activation for now, we scale the std by 2
             elif name in ["output_linear.0.weight"]:
                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                 nn.init.normal_(
-                    p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers)
+                    p,
+                    mean=0.0,
+                    std=initializer_range / math.sqrt(2 * self.config.num_layers),
                 )


@@ -388,16 +406,22 @@ def __init__(self, config, **kwargs) -> None:
         # Initialize weights and apply final processing
         self.post_init()

-    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None):
+    def forward(
+        self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None
+    ):
         output_hidden_states = (
             output_hidden_states
             if output_hidden_states is not None
             else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         hidden_states, all_hidden_states = self.backbone(
-            input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states
+            input_ids,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
         )
         if return_dict:
             return BaseModelOutputWithNoAttention(
@@ -451,13 +475,14 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutput]:
-
         output_hidden_states = (
             output_hidden_states
             if output_hidden_states is not None
             else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.hyena(
@@ -525,7 +550,9 @@ def forward(
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         transformer_outputs = self.hyena(
             input_ids,
@@ -542,7 +569,9 @@ def forward(
             batch_size = inputs_embeds.shape[0]

         if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+            raise ValueError(
+                "Cannot handle batch sizes > 1 if no padding token is defined."
+            )
         if self.config.pad_token_id is None:
             sequence_lengths = -1
         else:
@@ -553,7 +582,9 @@ def forward(
             else:
                 sequence_lengths = -1

-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[
+            torch.arange(batch_size, device=logits.device), sequence_lengths
+        ]

         loss = None
         if labels is not None:
@@ -576,7 +607,9 @@ def forward(
                     loss = loss_fct(pooled_logits, labels)
             elif self.config.problem_type == "single_label_classification":
                 loss_fct = nn.CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+                loss = loss_fct(
+                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
+                )
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = nn.BCEWithLogitsLoss()
                 loss = loss_fct(pooled_logits, labels)
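
For reference, a minimal standalone sketch of the pooled-logits gather that the sequence-classification hunks above reformat. The toy tensors, `pad_token_id` value, and shapes here are assumptions for illustration; only the gather expression itself comes from the diff.

```python
import torch

# Hypothetical toy setup, not values from the model config.
pad_token_id = 0
batch_size, seq_len, num_labels = 2, 6, 3

input_ids = torch.tensor([
    [5, 7, 2, 9, 0, 0],   # right-padded sequence
    [3, 4, 8, 1, 6, 2],   # full-length sequence
])
logits = torch.randn(batch_size, seq_len, num_labels)

# Last non-pad position per sequence; argmax over an all-False row is 0,
# so an unpadded sequence falls back to index -1 (its final token).
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1

# The gather that the diff wraps onto multiple lines: one logits row per sequence.
pooled_logits = logits[
    torch.arange(batch_size, device=logits.device), sequence_lengths
]
print(pooled_logits.shape)  # torch.Size([2, 3])
```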