@@ -353,7 +353,7 @@ def _reorder_cache(self, *args, **kwargs):
353
353
def prepare_inputs_for_generation (self , * args , ** kwargs ):
354
354
return self .model .prepare_inputs_for_generation (* args , ** kwargs )
355
355
356
- def _supports_logits_to_keep (self ) -> bool :
356
+ def _supports_num_logits_to_keep (self ) -> bool :
357
357
"""
358
358
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
359
359
to save memory. Checking it in this way allows to avoid using a new model attribute.
@@ -471,12 +471,13 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
471
471
def get_encoder (self , * args , ** kwargs ):
472
472
return self .model .get_encoder (* args , ** kwargs )
473
473
474
- def _supports_logits_to_keep (self ) -> bool :
474
+ def _supports_num_logits_to_keep (self ) -> bool :
475
475
"""
476
476
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
477
477
to save memory. Checking it in this way allows to avoid using a new model attribute.
478
478
"""
479
- return "logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
479
+ logits_to_keep_name = "logits_to_keep" if is_transformers_version (">" , "4.49" ) else "num_logits_to_keep"
480
+ return logits_to_keep_name in set (inspect .signature (self .model .forward ).parameters .keys ())
480
481
481
482
def _init_warmup (self ):
482
483
inputs = prepare_jit_inputs (self .model , self .export_feature , False )
0 commit comments