@@ -380,7 +380,6 @@ def prepare_inputs(
        **kwargs,
    ) -> Dict:
        batch_size = input_ids.shape[0]
-        duplication_indices = None
        if self.config.model_type == "bloom":
            batch_size *= self.config.num_attention_heads

@@ -463,9 +462,7 @@ def prepare_inputs(
            self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
        )

-        if self._first_iter_beam_search:
-            inputs, duplication_indices = self._deduplicate_inputs(inputs)
-        return inputs, duplication_indices
+        return inputs

    def forward(
        self,
@@ -477,13 +474,16 @@ def forward(
    ) -> CausalLMOutputWithPast:
        self.compile()

-        inputs, duplication_idicies = self.prepare_inputs(
+        inputs = self.prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            **kwargs,
        )
+
+        if self._first_iter_beam_search:
+            inputs, duplication_indices = self._deduplicate_inputs(inputs)
        # Run inference
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
@@ -512,7 +512,7 @@ def forward(
            past_key_values = None

        if self._first_iter_beam_search:
-            logits, past_key_values = self._expand_outputs_for_generation(duplication_idicies, logits, past_key_values)
+            logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values)
            self._first_iter_beam_search = False

        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
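For context, the hunks above move the first-iteration beam-search deduplication out of `prepare_inputs` and into `forward`, so inputs are collapsed to unique rows just before inference and the outputs are expanded back to one row per beam afterwards (they also fix the `duplication_idicies` typo). Below is a minimal standalone sketch of that round trip; the helper bodies are assumptions for illustration only, not the actual `_deduplicate_inputs` / `_expand_outputs_for_generation` implementations, which live on the model class.

```python
# Sketch of the deduplicate -> infer -> expand pattern, assuming a
# NumPy-based deduplication; names mirror the methods in the diff.
import numpy as np

def deduplicate_inputs(input_ids: np.ndarray):
    """Collapse identical rows (beam copies of one prompt) to unique rows."""
    unique_rows, duplication_indices = np.unique(input_ids, axis=0, return_inverse=True)
    return unique_rows, duplication_indices

def expand_outputs_for_generation(duplication_indices: np.ndarray, logits: np.ndarray):
    """Re-expand per-unique-row outputs back to one row per original beam."""
    return logits[duplication_indices]

# Usage: 4 beams but only 2 distinct prompts -> inference runs on 2 rows.
input_ids = np.array([[1, 2], [1, 2], [3, 4], [3, 4]])
unique_rows, idx = deduplicate_inputs(input_ids)
logits = np.random.rand(unique_rows.shape[0], 5)  # stand-in for model output
expanded = expand_outputs_for_generation(idx, logits)
assert expanded.shape[0] == input_ids.shape[0]
```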