
Commit 652a15c

Fix collecting duplicate tensors in quantization calibration dataset (#577)
* Added deepcopying of inputs collected by InferRequestWrapper. Added a test covering the fixed issue.
* Phrasing tweaks
* Add soundfile to test requirements
* Added librosa to test requirements
* Added copying to other data cache appends
* Remove the need for real test data
* Process __call__ call properly
* Addressed suggested changes
1 parent 2d8307e commit 652a15c
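
The change targets a reference-aliasing problem: InferRequestWrapper appended the inputs object it was handed directly to its data cache, so if the caller reuses and mutates that object between inference calls, the cache ends up holding many references to the same, last-written tensors instead of distinct calibration samples. A minimal standalone sketch of the effect and of the deepcopy fix (illustrative only, not the library code; the buffer name is hypothetical):

import copy

data_cache_by_reference = []
data_cache_deepcopied = []

inputs = {"input_ids": [0]}  # hypothetical input buffer reused by the caller
for step in range(3):
    inputs["input_ids"][0] = step                        # caller mutates the same object in place
    data_cache_by_reference.append(inputs)               # old behaviour: stores a reference
    data_cache_deepcopied.append(copy.deepcopy(inputs))  # fixed behaviour: stores a snapshot

print(data_cache_by_reference)  # three identical entries, all showing the last value
print(data_cache_deepcopied)    # three distinct entries: [{'input_ids': [0]}, {'input_ids': [1]}, {'input_ids': [2]}]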

File tree

2 files changed: +47, -3 lines

optimum/intel/openvino/quantization.py (+7, -3)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import inspect
 import logging
 import os
@@ -87,11 +88,14 @@ def __init__(self, request, data_cache=None):
         self.data_cache = data_cache
 
     def __call__(self, *args, **kwargs):
-        self.data_cache.append(*args)
+        # If __call__ is invoked then self.request must be an instance of CompiledModel
+        signature = inspect.signature(self.request)
+        bound_args = signature.bind(*args, **kwargs).arguments
+        self.data_cache.append(copy.deepcopy(bound_args["inputs"]))
         return self.request(*args, **kwargs)
 
     def infer(self, inputs: Any = None, share_inputs: bool = False):
-        self.data_cache.append(inputs)
+        self.data_cache.append(copy.deepcopy(inputs))
         return self.request.infer(inputs, share_inputs)
 
     def start_async(
@@ -102,7 +106,7 @@ def start_async(
         *,
         shared_memory: Any = None,
     ):
-        self.data_cache.append(inputs)
+        self.data_cache.append(copy.deepcopy(inputs))
         self.request.infer(inputs, share_inputs, share_outputs=True)
 
     def wait(self):
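
Note that the __call__ path cannot simply append *args: when the wrapped object is a CompiledModel, the inputs may arrive positionally or as a keyword argument, so the change binds the call to the callable's signature and looks the argument up by name before deep-copying it. A standalone sketch of that pattern, assuming the wrapped callable has an "inputs" parameter (the wrapper and stand-in function below are hypothetical, not the optimum-intel classes):

import copy
import inspect

class CachingWrapper:
    # Sketch: snapshot the "inputs" argument, then delegate to the wrapped callable.
    def __init__(self, fn, data_cache):
        self.fn = fn
        self.data_cache = data_cache

    def __call__(self, *args, **kwargs):
        # Bind positional and keyword arguments to parameter names so "inputs"
        # is found regardless of how the caller passed it.
        bound_args = inspect.signature(self.fn).bind(*args, **kwargs).arguments
        self.data_cache.append(copy.deepcopy(bound_args["inputs"]))
        return self.fn(*args, **kwargs)

def fake_compiled_model(inputs, share_inputs=False):  # hypothetical stand-in for the compiled model
    return inputs

cache = []
wrapped = CachingWrapper(fake_compiled_model, cache)
wrapped({"x": 1})         # inputs passed positionally
wrapped(inputs={"x": 2})  # inputs passed by keyword; both calls land in the cache
assert cache == [{"x": 1}, {"x": 2}]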

tests/openvino/test_quantization.py (+40)

@@ -16,10 +16,12 @@
 
 import tempfile
 import unittest
+from collections import defaultdict
 from functools import partial
 
 import evaluate
 import numpy as np
+import torch
 from datasets import load_dataset
 from parameterized import parameterized
 import openvino.runtime as ov
@@ -30,6 +32,7 @@
     AutoModelForCausalLM,
     AutoModelForTokenClassification,
     AutoTokenizer,
+    AutoProcessor,
     TrainingArguments,
     default_data_collator,
 )
@@ -45,6 +48,7 @@
     OVModelForSeq2SeqLM,
     OVModelForSequenceClassification,
     OVModelForTokenClassification,
+    OVModelForSpeechSeq2Seq,
     OVStableDiffusionPipeline,
     OVStableDiffusionXLPipeline,
     OVQuantizer,
@@ -54,6 +58,7 @@
 
 
 from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG
+from optimum.intel.openvino.quantization import InferRequestWrapper
 from optimum.intel.utils.import_utils import is_openvino_version
 from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8
 
@@ -589,3 +594,38 @@ def compute_metrics(p):
         tokens = tokenizer("This is a sample input", return_tensors="pt")
         outputs = model(**tokens)
         self.assertTrue("logits" in outputs)
+
+
+class InferRequestWrapperTest(unittest.TestCase):
+    MODEL_ID = ("openai/whisper-tiny.en",)
+
+    @staticmethod
+    def _generate_random_audio_data(processor):
+        t = np.linspace(0, 1.0, int(1000), endpoint=False)
+        audio_data = 0.5 * np.sin((2 + np.random.random()) * np.pi * t)
+        input_features = processor(
+            audio_data,
+            sampling_rate=16000,
+            return_tensors="pt",
+        ).input_features
+        return input_features
+
+    @parameterized.expand(MODEL_ID)
+    def test_calibration_data_uniqueness(self, model_id):
+        ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True)
+        processor = AutoProcessor.from_pretrained(model_id)
+
+        calibration_data = []
+        ov_model.decoder_with_past.request = InferRequestWrapper(ov_model.decoder_with_past.request, calibration_data)
+        for _ in range(2):
+            input_features = self._generate_random_audio_data(processor)
+            ov_model.generate(input_features)
+
+        data_hashes_per_key = defaultdict(list)
+        for inputs_dict in calibration_data:
+            for k, v in inputs_dict.items():
+                x = (v.numpy() if isinstance(v, torch.Tensor) else v).copy()
+                data_hashes_per_key[k].append(hash(x.tobytes()))
+        for k, data_hashes in data_hashes_per_key.items():
+            # All hashes can not be equal because calibration dataset contains at least 2 different samples
+            self.assertTrue(any(data_hashes[0] != it for it in data_hashes))
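
The assertion relies on the fact that hashing a tensor's raw bytes yields equal hashes only for byte-identical data, so observing at least one differing hash per input key proves the cache contains more than one distinct sample. A small illustration of that check in isolation (hypothetical arrays, not the Whisper features the test actually collects):

import numpy as np

a = np.zeros(4, dtype=np.float32)
b = np.ones(4, dtype=np.float32)

hashes = [hash(x.copy().tobytes()) for x in (a, a, b)]
assert hashes[0] == hashes[1]                # byte-identical arrays hash identically
assert any(hashes[0] != h for h in hashes)   # a differing sample is detected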
