@@ -358,8 +358,14 @@ def _supports_logits_to_keep(self) -> bool:
358
358
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
359
359
to save memory. Checking it in this way allows to avoid using a new model attribute.
360
360
"""
361
- logits_to_keep_name = "logits_to_keep" if is_transformers_version (">" , "4.49" ) else "num_logits_to_keep"
362
- return logits_to_keep_name in set (inspect .signature (self .model .forward ).parameters .keys ())
361
+ return "logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
362
+
363
+ def _supports_num_logits_to_keep (self ) -> bool :
364
+ """
365
+ Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
366
+ to save memory. Checking it in this way allows to avoid using a new model attribute.
367
+ """
368
+ return "num_logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
363
369
364
370
def generate (self , * args , ** kwargs ):
365
371
if self ._add_patch and kwargs .get ("assistant_model" , None ):
@@ -478,6 +484,13 @@ def _supports_logits_to_keep(self) -> bool:
478
484
"""
479
485
return "logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
480
486
487
+ def _supports_num_logits_to_keep (self ) -> bool :
488
+ """
489
+ Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
490
+ to save memory. Checking it in this way allows to avoid using a new model attribute.
491
+ """
492
+ return "num_logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
493
+
481
494
def _init_warmup (self ):
482
495
inputs = prepare_jit_inputs (self .model , self .export_feature , False )
483
496
self .generate (input_ids = inputs ["input_ids" ], attention_mask = inputs ["attention_mask" ], max_new_tokens = 4 )
0 commit comments