Skip to content

Commit daecdac

Browse files
committed
move inputs modification into forward
1 parent 59c8c40 commit daecdac

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

optimum/intel/openvino/modeling_decoder.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -380,7 +380,6 @@ def prepare_inputs(
380380
**kwargs,
381381
) -> Dict:
382382
batch_size = input_ids.shape[0]
383-
duplication_indices = None
384383
if self.config.model_type == "bloom":
385384
batch_size *= self.config.num_attention_heads
386385

@@ -463,9 +462,7 @@ def prepare_inputs(
463462
self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
464463
)
465464

466-
if self._first_iter_beam_search:
467-
inputs, duplication_indices = self._deduplicate_inputs(inputs)
468-
return inputs, duplication_indices
465+
return inputs
469466

470467
def forward(
471468
self,
@@ -477,13 +474,16 @@ def forward(
477474
) -> CausalLMOutputWithPast:
478475
self.compile()
479476

480-
inputs, duplication_idicies = self.prepare_inputs(
477+
inputs = self.prepare_inputs(
481478
input_ids=input_ids,
482479
attention_mask=attention_mask,
483480
past_key_values=past_key_values,
484481
position_ids=position_ids,
485482
**kwargs,
486483
)
484+
485+
if self._first_iter_beam_search:
486+
inputs, duplication_indices = self._deduplicate_inputs(inputs)
487487
# Run inference
488488
self.request.start_async(inputs, share_inputs=True)
489489
self.request.wait()
@@ -512,7 +512,7 @@ def forward(
512512
past_key_values = None
513513

514514
if self._first_iter_beam_search:
515-
logits, past_key_values = self._expand_outputs_for_generation(duplication_idicies, logits, past_key_values)
515+
logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values)
516516
self._first_iter_beam_search = False
517517

518518
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

optimum/intel/openvino/quantization.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -688,7 +688,7 @@ def _prepare_builtin_dataset(self, quantization_config: OVWeightQuantizationConf
688688
nsamples = quantization_config.num_samples if quantization_config.num_samples else 128
689689
calibration_dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples)
690690
calibration_dataset = prepare_dataset(calibration_dataset)
691-
calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x)[0])
691+
calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x))
692692

693693
return calibration_dataset
694694

0 commit comments

Comments
 (0)