@@ -317,6 +317,13 @@ def _reorder_cache(self, *args, **kwargs):
317
317
def prepare_inputs_for_generation(self, *args, **kwargs):
    """Delegate generation input preparation to the wrapped model.

    This class is a thin wrapper around ``self.model``; all arguments are
    forwarded unchanged and the wrapped model's result is returned as-is.
    """
    delegate = self.model.prepare_inputs_for_generation
    return delegate(*args, **kwargs)
319
319
320
+ def _supports_num_logits_to_keep (self ) -> bool :
321
+ """
322
+ Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
323
+ to save memory. Checking it in this way allows to avoid using a new model attribute.
324
+ """
325
+ return "num_logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
326
+
320
327
def generate (self , * args , ** kwargs ):
321
328
if self ._add_patch and kwargs .get ("assistant_model" , None ):
322
329
raise ValueError (
@@ -427,6 +434,13 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
427
434
def get_encoder(self, *args, **kwargs):
    """Return the encoder of the wrapped model.

    Pure pass-through: every positional and keyword argument is handed to
    ``self.model.get_encoder`` and its result is returned unchanged.
    """
    delegate = self.model.get_encoder
    return delegate(*args, **kwargs)
429
436
437
+ def _supports_num_logits_to_keep (self ) -> bool :
438
+ """
439
+ Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
440
+ to save memory. Checking it in this way allows to avoid using a new model attribute.
441
+ """
442
+ return "num_logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
443
+
430
444
def _init_warmup(self):
    """Run one tiny generation pass to warm the model up.

    Builds example inputs via ``prepare_jit_inputs`` (helper defined elsewhere
    in this file) and triggers a short ``generate`` call (4 new tokens).
    """
    warmup_inputs = prepare_jit_inputs(self.model, self.export_feature, False)
    self.generate(
        input_ids=warmup_inputs["input_ids"],
        attention_mask=warmup_inputs["attention_mask"],
        max_new_tokens=4,
    )
0 commit comments