@@ -338,7 +338,6 @@ def compile(self):
         if self.compiled_model is None:
             super().compile()
             self.compiled_model = self.request
-            # self.request = self.request.create_infer_request()
 
     def _make_stateful(self):
         patch_stateful(self.config, self.model)
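This hunk drops the eager infer-request creation from `compile()`; requests are now created lazily by the callers shown below. A minimal usage sketch of the resulting compile-once behavior (illustrative only; `model` stands for an already-loaded `OVModelForCausalLM` instance):

```python
model.compile()  # compiles on first call and caches the result in self.compiled_model
model.compile()  # no-op: self.compiled_model is already set
# infer requests are created on demand by generate(), not during compile()
request = model.compiled_model.create_infer_request()
```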
@@ -358,16 +357,11 @@ class OVModelForCausalLM(OVBaseDecoderModel, GenerationMixin):
 
     def generate(self, *args, **kwargs):
         self.compile()
-        infer_context = [self.compiled_model.create_infer_request()]
-        kwargs["infer_context"] = infer_context
+        if kwargs.get("infer_context") is None:
+            infer_context = [self.compiled_model.create_infer_request()]
+            kwargs["infer_context"] = infer_context
         return super().generate(*args, **kwargs)
 
-    def __call__(self, *args, **kwargs):
-        self.compile()
-        infer_context = [self.compiled_model.create_infer_request()]
-        kwargs["infer_context"] = infer_context
-        return super().__call__(*args, **kwargs)
-
     @add_start_docstrings_to_model_forward(
         INPUTS_DOCSTRING.format("batch_size, sequence_length")
         + TEXT_GENERATION_EXAMPLE.format(
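With this change, `generate()` only creates a fresh infer request when the caller has not supplied one, so a single request can be reused across calls. A hedged usage sketch, assuming the `infer_context` kwarg introduced by this patch (the model id is a placeholder, not a real checkpoint):

```python
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "some-org/some-ov-causal-lm"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id)

model.compile()
# Create one request up front and reuse it across generate() calls
request = model.compiled_model.create_infer_request()

inputs = tokenizer("Hello", return_tensors="pt")
# Passing infer_context suppresses the per-call request creation in generate()
outputs = model.generate(**inputs, infer_context=[request])
```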
@@ -482,7 +476,7 @@ def forward(
                 # for stateful models, the infer request is created in the generate method and passed through the generation cycle via the past_key_values param
                 infer_request = past_key_values[1]
             else:
-                if infer_context[0] is not None:
+                if infer_context is not None:
                     infer_request = infer_context[
                         0
                     ]  # Use the inference request passed in kwargs if provided, create a new one otherwise
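The selection logic above threads a single infer request through the whole generation loop: for stateful models it rides along inside the `past_key_values` tuple, otherwise it comes from the `infer_context` kwarg or is created fresh. A condensed sketch of that precedence (the helper name and standalone signature are illustrative, not part of the patch):

```python
def _select_infer_request(compiled_model, past_key_values=None, infer_context=None):
    # 1) stateful path: reuse the request threaded through past_key_values
    if past_key_values is not None:
        return past_key_values[1]
    # 2) request passed explicitly via the infer_context kwarg
    if infer_context is not None:
        return infer_context[0]
    # 3) otherwise fall back to a fresh request
    return compiled_model.create_infer_request()
```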
@@ -501,7 +495,7 @@ def forward(
         if not self.stateful:
             if self.use_cache:
                 # Tuple of length equal to: number of layers * number of past_key_values per decoder layer (2 corresponds to the self-attention layer)
-                past_key_values = tuple(infer_context[0].get_tensor(key).data for key in self.key_value_output_names)
+                past_key_values = tuple(infer_request.get_tensor(key).data for key in self.key_value_output_names)
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
                     # Tuple of tuples of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                     past_key_values = tuple(
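For non-stateful models, the block above reads one output tensor per key/value name from the request and then regroups the flat tuple into per-layer `(key, value)` pairs. A tiny self-contained illustration of that regrouping (the strings stand in for tensors):

```python
# Flat tuple as read from the infer request: k/v for layer 0, then layer 1, ...
flat = ("key_0", "value_0", "key_1", "value_1")
# Regroup into one (key, value) pair per decoder layer
per_layer = tuple(flat[i : i + 2] for i in range(0, len(flat), 2))
assert per_layer == (("key_0", "value_0"), ("key_1", "value_1"))
```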
@@ -690,9 +684,6 @@ def _reorder_cache(
             batch_size = beam_idx.shape[0]
             indices = np.array(range(batch_size * self.config.num_attention_heads))
             indices = indices.reshape([batch_size, self.config.num_attention_heads])
-            # self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
-            # return past_key_values
-            # print("_reorder_cache output", np.take(indices, beam_idx, 0).flatten())
             return ((np.take(indices, beam_idx, 0).flatten()), past_key_values[1])
         else:
             standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
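In the stateful branch of `_reorder_cache`, the head-index table plus `np.take(..., beam_idx, 0)` expands each selected beam into the per-head index slots of the internal state. A small worked example (values illustrative):

```python
import numpy as np

batch_size, num_heads = 2, 2
indices = np.arange(batch_size * num_heads).reshape([batch_size, num_heads])
# indices == [[0, 1], [2, 3]]: row b holds the head slots of beam b
beam_idx = np.array([1, 0])  # beam search kept beam 1 first, then beam 0
print(np.take(indices, beam_idx, 0).flatten())  # -> [2 3 0 1]
```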