Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deepcopy of ov.Tensor #1146

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
PREDEFINED_SD_DATASETS,
PREDEFINED_SPEECH_TO_TEXT_DATASETS,
PREDEFINED_VISUAL_LM_DATASETS,
deepcopy_data,
)


Expand Down Expand Up @@ -131,7 +132,7 @@ def __init__(

def collect_inputs(self, inputs):
if not self.apply_caching or not isinstance(inputs, dict):
self.collected_inputs.append(copy.deepcopy(inputs))
self.collected_inputs.append(deepcopy_data(inputs))
return

copied_inputs = {}
Expand All @@ -146,7 +147,7 @@ def collect_inputs(self, inputs):
# Avoid data copying if tensor contains data encountered earlier
self.tensor_cache.setdefault(k, {})
if data_hash not in self.tensor_cache[k]:
self.tensor_cache[k][data_hash] = copy.deepcopy(v)
self.tensor_cache[k][data_hash] = deepcopy_data(v)
copied_inputs[k] = self.tensor_cache[k][data_hash]
self.collected_inputs.append(copied_inputs)

Expand Down
22 changes: 20 additions & 2 deletions optimum/intel/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
import stat
import warnings
import weakref
from copy import deepcopy
from glob import glob
from pathlib import Path
from tempfile import TemporaryDirectory as OrigTemporaryDirectory
from tempfile import mkdtemp
from typing import Tuple, Type, Union
from typing import Any, Tuple, Type, Union

import numpy as np
import torch
from huggingface_hub import model_info
from openvino.runtime import Core, Model, properties
from openvino.runtime import Core, Model, Tensor, properties
from openvino.runtime import Type as OVType
from packaging.version import Version
from transformers import AutoTokenizer, CLIPTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
Expand Down Expand Up @@ -586,3 +587,20 @@ def check_scale_available(model: Union[Model, str, Path]):
if runtime_options is None:
return False
return runtime_options.find("ACTIVATIONS_SCALE_FACTOR") is not None


def deepcopy_data(inputs: Any) -> Any:
if isinstance(inputs, dict):
new_inputs = {}
for k, v in inputs.items():
new_inputs[deepcopy_data(k)] = deepcopy_data(v)
elif isinstance(inputs, list):
new_inputs = [deepcopy_data(elem) for elem in inputs]
elif isinstance(inputs, tuple):
new_inputs = tuple(deepcopy_data(elem) for elem in inputs)
elif isinstance(inputs, Tensor):
new_inputs = Tensor(np.empty_like(inputs.data), inputs.get_shape(), inputs.get_element_type())
new_inputs.copy_from(inputs)
else:
new_inputs = deepcopy(inputs)
return new_inputs
39 changes: 38 additions & 1 deletion tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
_DEFAULT_4BIT_CONFIGS,
_DEFAULT_4BIT_CONFIG,
)
from optimum.intel.openvino.utils import TemporaryDirectory
from optimum.intel.openvino.utils import TemporaryDirectory, deepcopy_data
from copy import deepcopy

from optimum.intel.openvino.quantization import InferRequestWrapper
Expand Down Expand Up @@ -1354,6 +1354,43 @@ def test_calibration_data_uniqueness(self, model_name, apply_caching):
# Without caching, encoder hidden states tensors will be unique for each collected input
self.assertGreater(len(data_id_per_key["encoder_hidden_states"]), 2)

def test_deepcopy_data(self):
data = {
"a": torch.tensor([1, 2, 3]),
"b": np.array([1, 2, 3]),
"c": 1,
"d": "string",
"e": {"a": torch.tensor([1, 2, 3]), "b": np.array([1, 2, 3])},
"f": [
ov.Tensor(np.ones((1, 2, 3)), (1, 2, 3), ov.Type.i4),
ov.Tensor(np.empty((1, 2, 3)), (2, 3, 1), ov.Type.i4),
],
}
copied_data = deepcopy_data(data)

# Checks that objects have different IDs
self.assertTrue(copied_data is not data)
self.assertTrue(copied_data["a"] is not data["a"])
self.assertTrue(copied_data["b"] is not data["b"])
self.assertTrue(copied_data["e"]["a"] is not data["e"]["a"])
self.assertTrue(copied_data["e"]["b"] is not data["e"]["b"])
self.assertTrue(copied_data["f"][0] is not data["f"][0])
self.assertTrue(copied_data["f"][1] is not data["f"][1])

# Checks that constant objects have the same IDs
self.assertTrue(copied_data["c"] is data["c"])
self.assertTrue(copied_data["d"] is data["d"])

# Checks that objects have the same data
self.assertTrue(torch.equal(copied_data["a"], data["a"]))
self.assertTrue(np.array_equal(copied_data["b"], data["b"]))
self.assertTrue(copied_data["c"] == data["c"])
self.assertTrue(copied_data["d"] == data["d"])
self.assertTrue(torch.equal(copied_data["e"]["a"], data["e"]["a"]))
self.assertTrue(np.array_equal(copied_data["e"]["b"], data["e"]["b"]))
self.assertTrue(np.array_equal(copied_data["f"][0].data, data["f"][0].data))
self.assertTrue(np.array_equal(copied_data["f"][1].data, data["f"][1].data))


def check_optimization_not_applicable_to_optimized_model(model, quantization_config):
quantizer = OVQuantizer(model)
Expand Down
Loading