@@ -353,8 +353,17 @@ def _reorder_cache(self, *args, **kwargs):
353
353
def prepare_inputs_for_generation (self , * args , ** kwargs ):
354
354
return self .model .prepare_inputs_for_generation (* args , ** kwargs )
355
355
356
+ def _supports_logits_to_keep (self ) -> bool :
357
+ """
358
+ Return True if the current model supports the keyword argument `logits_to_keep` in forward()
359
+ to save memory. Checking it in this way allows to avoid using a new model attribute.
360
+ """
361
+ return "logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
362
+
356
363
def _supports_num_logits_to_keep (self ) -> bool :
357
364
"""
365
+ Will be deprecated after we no longer support transformers < 4.49
366
+
358
367
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
359
368
to save memory. Checking it in this way allows to avoid using a new model attribute.
360
369
"""
@@ -470,8 +479,17 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
470
479
def get_encoder (self , * args , ** kwargs ):
471
480
return self .model .get_encoder (* args , ** kwargs )
472
481
482
+ def _supports_logits_to_keep (self ) -> bool :
483
+ """
484
+ Return True if the current model supports the keyword argument `logits_to_keep` in forward()
485
+ to save memory. Checking it in this way allows to avoid using a new model attribute.
486
+ """
487
+ return "logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
488
+
473
489
def _supports_num_logits_to_keep (self ) -> bool :
474
490
"""
491
+ Will be deprecated after we no longer support transformers < 4.49
492
+
475
493
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
476
494
to save memory. Checking it in this way allows to avoid using a new model attribute.
477
495
"""
0 commit comments