From d5f0d48ee0355c8b8a0367d2333b7402dd13a30d Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Fri, 22 Nov 2024 07:56:14 -0800 Subject: [PATCH 01/17] Added numerics methods for tensorflow tensor --- docs/api/source/conf.py | 4 +- nncf/tensor/definitions.py | 1 + nncf/tensor/functions/__init__.py | 4 + nncf/tensor/functions/dispatcher.py | 4 + nncf/tensor/functions/tf_linalg.py | 100 ++++ nncf/tensor/functions/tf_numeric.py | 534 ++++++++++++++++++ nncf/tensor/tensor.py | 2 +- nncf/version.py | 2 +- .../template_test_nncf_tensor.py | 37 +- tests/tensorflow/test_tensor.py | 187 ++++++ 10 files changed, 859 insertions(+), 16 deletions(-) create mode 100644 nncf/tensor/functions/tf_linalg.py create mode 100644 nncf/tensor/functions/tf_numeric.py create mode 100644 tests/tensorflow/test_tensor.py diff --git a/docs/api/source/conf.py b/docs/api/source/conf.py index 42544472705..e008c1f704d 100644 --- a/docs/api/source/conf.py +++ b/docs/api/source/conf.py @@ -145,8 +145,8 @@ def collect_api_entities() -> APIInfo: "nncf.tensor.functions.torch_linalg", "nncf.tensor.functions.torch_io", "nncf.tensor.functions.numpy_io", - "nncf.tensor.functions.openvino_numeric", - "nncf.torch.dynamic_graph.patch_pytorch", + "nncf.tensor.functions.tf_numeric", + "nncf.tensor.functions.tf_linalg", ] with mock(mock_modules): diff --git a/nncf/tensor/definitions.py b/nncf/tensor/definitions.py index 576a812ec7b..75995c3e22b 100644 --- a/nncf/tensor/definitions.py +++ b/nncf/tensor/definitions.py @@ -26,6 +26,7 @@ class TensorBackend(Enum): """ numpy = auto() + tf = auto() torch = auto() ov = auto() diff --git a/nncf/tensor/functions/__init__.py b/nncf/tensor/functions/__init__.py index 568a4444ffc..c5d21dbc5d2 100644 --- a/nncf/tensor/functions/__init__.py +++ b/nncf/tensor/functions/__init__.py @@ -74,6 +74,10 @@ def _initialize_backends() -> None: import nncf.tensor.functions.numpy_linalg import nncf.tensor.functions.numpy_numeric + with contextlib.suppress(ImportError): + import nncf.tensor.functions.tf_linalg + import nncf.tensor.functions.tf_numeric + with contextlib.suppress(ImportError): import nncf.tensor.functions.torch_io import nncf.tensor.functions.torch_linalg diff --git a/nncf/tensor/functions/dispatcher.py b/nncf/tensor/functions/dispatcher.py index e557aaadfbb..c97574cc596 100644 --- a/nncf/tensor/functions/dispatcher.py +++ b/nncf/tensor/functions/dispatcher.py @@ -285,6 +285,10 @@ def get_numeric_backend_fn(fn_name: str, backend: TensorBackend) -> Callable[... from nncf.tensor.functions import torch_numeric return getattr(torch_numeric, fn_name) + if backend == TensorBackend.tf: + from nncf.tensor.functions import tf_numeric + + return getattr(tf_numeric, fn_name) msg = f"Unsupported backend type: {backend}" raise ValueError(msg) diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py new file mode 100644 index 00000000000..a13e8f48b05 --- /dev/null +++ b/nncf/tensor/functions/tf_linalg.py @@ -0,0 +1,100 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from nncf.tensor.functions import linalg + + +@linalg.norm.register(tf.Tensor) +def _( + a: tf.Tensor, + ord: Optional[Union[str, float, int]] = None, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> tf.Tensor: + if axis is None: + axis = 0 if a._rank() == 1 else (0, 1) + + if ord is None or (a._rank() == 1 and ord == "fro"): + ord = "euclidean" + + with tf.device(a.device): + if ord == "nuc": + s, _, _ = tf.linalg.svd(a) + result = tf.reduce_sum(s) + + if keepdims: + result_shape = [1 if i in axis else dim for i, dim in enumerate(a.shape)] + result = tf.reshape(result, result_shape) + return result + + if ord == 0: + return tf.cast(tf.math.count_nonzero(a, axis=axis, keepdims=keepdims), a.dtype) + + return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims) + + +@linalg.cholesky.register(tf.Tensor) +def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor: + with tf.device(a.device): + cholesky = tf.linalg.cholesky(a) + if upper: + perm = list(range(tf.rank(a))) + perm[-1], perm[-2] = perm[-2], perm[-1] + cholesky = tf.transpose(cholesky, perm=perm) + return cholesky + + +@linalg.cholesky_inverse.register(tf.Tensor) +def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor: + with tf.device(a.device): + if upper: + perm = list(range(tf.rank(a))) + perm[-1], perm[-2] = perm[-2], perm[-1] + a = tf.transpose(a, perm=perm) + + eye = tf.eye(a.shape[0], dtype=a.dtype) + return tf.linalg.cholesky_solve(a, eye) + + +@linalg.inv.register(tf.Tensor) +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + return tf.linalg.inv(a) + + +@linalg.pinv.register(tf.Tensor) +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + return tf.linalg.pinv(a) + + +@linalg.lstsq.register(tf.Tensor) +def _(a: tf.Tensor, b: tf.Tensor, driver: Optional[str] = None) -> tf.Tensor: + with tf.device(a.device): + if driver is not None: + warnings.warn("Driver specifying is not supported in TensorFlow lstsq method") + if tf.rank(b) == 1: + b = tf.expand_dims(b, axis=1) + + return tf.linalg.lstsq(a, b) + + +@linalg.svd.register(tf.Tensor) +def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> tf.Tensor: + with tf.device(a.device): + s, u, v = tf.linalg.svd(a, full_matrices=full_matrices) + + return u, s, tf.transpose(v) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py new file mode 100644 index 00000000000..4e3afd29c53 --- /dev/null +++ b/nncf/tensor/functions/tf_numeric.py @@ -0,0 +1,534 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from nncf.tensor import TensorDataType +from nncf.tensor import TensorDeviceType +from nncf.tensor.definitions import TensorBackend +from nncf.tensor.definitions import TypeInfo +from nncf.tensor.functions import numeric as numeric + +DTYPE_MAP = { + TensorDataType.float16: tf.float16, + TensorDataType.bfloat16: tf.bfloat16, + TensorDataType.float32: tf.float32, + TensorDataType.float64: tf.float64, + TensorDataType.int8: tf.int8, + TensorDataType.int32: tf.int32, + TensorDataType.int64: tf.int64, + TensorDataType.uint8: tf.uint8, +} + +DEVICE_MAP = {TensorDeviceType.CPU: "CPU", TensorDeviceType.GPU: "GPU"} + +DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()} +DEVICE_MAP_REV = {v: k for k, v in DEVICE_MAP.items()} + + +@numeric.device.register(tf.Tensor) +def _(a: tf.Tensor) -> TensorDeviceType: + return DEVICE_MAP_REV[a.device.split("/")[-1].split(":")[1]] + + +@numeric.backend.register(tf.Tensor) +def _(a: tf.Tensor) -> TensorBackend: + return TensorBackend.tf + + +@numeric.squeeze.register(tf.Tensor) +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: + with tf.device(a.device): + if axis is None: + return tf.squeeze(a) + if isinstance(axis, Tuple) and any(a.shape[i] != 1 for i in axis): + raise ValueError("Cannot select an axis to squeeze out which has size not equal to one") + return tf.squeeze(a, axis) + + +@numeric.flatten.register(tf.Tensor) +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + return tf.reshape(a, [-1]) + + +@numeric.max.register(tf.Tensor) +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> tf.Tensor: + with tf.device(a.device): + if axis is None: + return tf.reduce_max(a) + return tf.reduce_max(a, axis=axis, keepdims=keepdim) + + +@numeric.min.register(tf.Tensor) +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> tf.Tensor: + with tf.device(a.device): + if axis is None: + return tf.reduce_min(a) + return tf.reduce_min(a, axis=axis, keepdims=keepdim) + + +@numeric.abs.register(tf.Tensor) +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + return tf.abs(a) + + +@numeric.astype.register(tf.Tensor) +def _(a: tf.Tensor, dtype: TensorDataType) -> tf.Tensor: + with tf.device(a.device): + return tf.cast(a, DTYPE_MAP[dtype]) + + +@numeric.dtype.register(tf.Tensor) +def _(a: tf.Tensor) -> TensorDataType: + return DTYPE_MAP_REV[a.dtype] + + +@numeric.reshape.register(tf.Tensor) +def _(a: tf.Tensor, shape: Tuple[int, ...]) -> tf.Tensor: + with tf.device(a.device): + return tf.reshape(a, shape) + + +@numeric.all.register(tf.Tensor) +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: + with tf.device(a.device): + if axis is None: + return tf.reduce_all(a) + return tf.reduce_all(a, axis=axis) + + +@numeric.allclose.register(tf.Tensor) +def _( + a: tf.Tensor, b: Union[tf.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +) -> bool: + with tf.device(a.device): + if not isinstance(b, tf.Tensor): + b = tf.constant(b) + return tf.experimental.numpy.allclose(a, tf.cast(b, a.dtype), rtol=rtol, atol=atol, equal_nan=equal_nan) + + +@numeric.any.register(tf.Tensor) +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: + with tf.device(a.device): + if axis is None: + return tf.reduce_any(a) + return tf.reduce_any(a, axis=axis) + + +@numeric.count_nonzero.register(tf.Tensor) +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: + with tf.device(a.device): + return tf.math.count_nonzero(a, axis=axis) + + +@numeric.isempty.register(tf.Tensor) +def _(a: tf.Tensor) -> bool: + return bool(tf.equal(tf.size(a), 0).numpy().T) + + +@numeric.isclose.register(tf.Tensor) +def _( + a: tf.Tensor, b: Union[tf.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +) -> tf.Tensor: + with tf.device(a.device): + return tf.experimental.numpy.isclose(a, tf.cast(b, a.dtype), atol=atol, rtol=rtol, equal_nan=equal_nan) + + +@numeric.maximum.register(tf.Tensor) +def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor: + with tf.device(x1.device): + return tf.maximum(x1, x2) + + +@numeric.minimum.register(tf.Tensor) +def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor: + with tf.device(x1.device): + return tf.minimum(x1, x2) + + +@numeric.ones_like.register(tf.Tensor) +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + return tf.ones_like(a) + + +@numeric.where.register(tf.Tensor) +def _(condition: tf.Tensor, x: Union[tf.Tensor, float, bool], y: Union[tf.Tensor, float, bool]) -> tf.Tensor: + with tf.device(condition.device): + return tf.where(condition, x, y) + + +@numeric.zeros_like.register(tf.Tensor) +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + return tf.zeros_like(a) + + +@numeric.stack.register(tf.Tensor) +def _(x: List[tf.Tensor], axis: int = 0) -> tf.Tensor: + with tf.device(x[0].device): + return tf.stack(x, axis=axis) + + +@numeric.concatenate.register(tf.Tensor) +def _(x: List[tf.Tensor], axis: int = 0) -> tf.Tensor: + with tf.device(x[0].device): + return tf.concat(x, axis=axis) + + +@numeric.unstack.register(tf.Tensor) +def _(x: tf.Tensor, axis: int = 0) -> List[tf.Tensor]: + with tf.device(x.device): + if not list(x.shape): + tf.expand_dims(x, 0) + return tf.unstack(x, axis=axis) + + +@numeric.moveaxis.register(tf.Tensor) +def _(a: tf.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> tf.Tensor: + perm = list(range(a._rank())) + if isinstance(source, int): + axe_to_move = perm.pop(source) + if destination < 0: + destination = len(perm) + destination + 1 + perm.insert(destination, axe_to_move) + else: + old_perm = perm[:] + for i in range(len(source)): + perm[destination[i]] = old_perm[source[i]] + with tf.device(a.device): + return tf.transpose(a, perm) + + +@numeric.mean.register(tf.Tensor) +def _( + a: tf.Tensor, + axis: Union[int, Tuple[int, ...]] = None, + keepdims: bool = False, + dtype: Optional[TensorDataType] = None, +) -> tf.Tensor: + with tf.device(a.device): + return tf.reduce_mean(a, axis=axis, keepdims=keepdims) + + +@numeric.median.register(tf.Tensor) +def _( + a: tf.Tensor, + axis: Union[int, Tuple[int, ...]] = None, + keepdims: bool = False, +) -> tf.Tensor: + numpy_a = np.array(a) + numpy_median = np.median(numpy_a, axis=axis, keepdims=keepdims) + + with tf.device(a.device): + tf_median = tf.constant(numpy_median) + + return tf_median + + +@numeric.round.register(tf.Tensor) +def _(a: tf.Tensor, decimals: int = 0) -> tf.Tensor: + scale_factor = 10**decimals + scaled_tensor = a * scale_factor + with tf.device(a.device): + rounded_tensor = tf.round(scaled_tensor) + return rounded_tensor / scale_factor + + +@numeric.power.register(tf.Tensor) +def _(a: tf.Tensor, exponent: Union[tf.Tensor, float]) -> tf.Tensor: + with tf.device(a.device): + return tf.pow(a, exponent) + + +@numeric.quantile.register(tf.Tensor) +def quantile( + a: tf.Tensor, + q: Union[float, List[float]], + axis: Optional[Union[int, Tuple[int]]] = None, + keepdims: bool = False, +) -> tf.Tensor: + a_np = a.numpy() + quantile_np = np.quantile(a_np, q=q, axis=axis, keepdims=keepdims) + with tf.device(a.device): + return tf.constant(quantile_np) + + +@numeric.percentile.register(tf.Tensor) +def _( + a: tf.Tensor, + q: Union[float, List[float]], + axis: Union[int, Tuple[int, ...], List[int]], + keepdims: bool = False, +) -> List[Union[tf.Tensor, np.generic]]: + with tf.device(a.device): + q = [x / 100 for x in q] if isinstance(q, (list, tuple)) else q / 100 + return numeric.quantile(a, q=q, axis=axis, keepdims=keepdims) + + +@numeric._binary_op_nowarn.register(tf.Tensor) +def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor: + with tf.device(a.device): + return operator_fn(a, b) + + +@numeric._binary_reverse_op_nowarn.register(tf.Tensor) +def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor: + with tf.device(a.device): + return operator_fn(b, a) + + +@numeric.clip.register(tf.Tensor) +def _(a: tf.Tensor, a_min: Union[tf.Tensor, float], a_max: Union[tf.Tensor, float]) -> tf.Tensor: + with tf.device(a.device): + return tf.clip_by_value(a, a_min, a_max) + + +@numeric.finfo.register(tf.Tensor) +def _(a: tf.Tensor) -> TypeInfo: + ti = tf.experimental.numpy.finfo(a.dtype) + return TypeInfo(ti.eps, ti.max, ti.min) + + +@numeric.as_tensor_like.register(tf.Tensor) +def _(a: tf.Tensor, data: Any) -> tf.Tensor: + with tf.device(a.device): + return tf.convert_to_tensor(data) + + +@numeric.item.register(tf.Tensor) +def _(a: tf.Tensor) -> Union[int, float, bool]: + np_item = a.numpy() + if isinstance(np_item, np.floating): + return float(np_item) + if isinstance(np_item, np.bool_): + return bool(np_item) + + return int(np_item) + + +@numeric.sum.register(tf.Tensor) +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor: + with tf.device(a.device): + return tf.reduce_sum(a, axis=axis, keepdims=keepdims) + + +@numeric.multiply.register(tf.Tensor) +def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor: + with tf.device(x1.device): + return tf.multiply(x1, x2) + + +@numeric.var.register(tf.Tensor) +def _( + a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ddof: int = 0 +) -> tf.Tensor: + with tf.device(a.device): + assert ddof in {0, 1} + tf_var = tf.math.reduce_variance(a, axis=axis, keepdims=keepdims) + if ddof: + n = tf.shape(a)[axis] if axis is not None else tf.size(a) + tf_var *= float(n) / float(n - 1) + return tf_var + + +@numeric.size.register(tf.Tensor) +def _(a: tf.Tensor) -> int: + return tf.size(a) + + +@numeric.matmul.register(tf.Tensor) +def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: + with tf.device(x1.device): + return tf.matmul(x1, x2) + + +@numeric.unsqueeze.register(tf.Tensor) +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: + with tf.device(a.device): + return tf.expand_dims(a, axis=axis) + + +@numeric.transpose.register(tf.Tensor) +def _(a: tf.Tensor, axes: Optional[Tuple[int, ...]] = None) -> tf.Tensor: + with tf.device(a.device): + return tf.transpose(a, perm=axes) + + +@numeric.argsort.register(tf.Tensor) +def _(a: tf.Tensor, axis: int = -1, descending=False, stable=False) -> tf.Tensor: + with tf.device(a.device): + direction = "DESCENDING" if descending else "ASCENDING" + return tf.argsort(a, axis=axis, direction=direction, stable=stable) + + +@numeric.diag.register(tf.Tensor) +def _(a: tf.Tensor, k: int = 0) -> tf.Tensor: + with tf.device(a.device): + if a._rank() == 2: + if k == 0: + return tf.linalg.diag_part(a) + elif k > 0: + return tf.linalg.diag_part(a[:, k:]) + else: + return tf.linalg.diag_part(a[-k:, :]) + + if a._rank() == 1: + return tf.linalg.diag(a, k=k) + + +@numeric.logical_or.register(tf.Tensor) +def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: + with tf.device(x1.device): + return tf.logical_or(x1, x2) + + +@numeric.masked_mean.register(tf.Tensor) +def _( + x: tf.Tensor, mask: Optional[tf.Tensor], axis: Union[int, Tuple[int, ...], List[int]], keepdims=False +) -> tf.Tensor: + with tf.device(x.device): + if mask is None: + return tf.reduce_mean(x, axis=axis, keepdims=keepdims) + flipped_mask = ~mask + valid_counts = tf.reduce_sum(tf.cast(flipped_mask, x.dtype), axis=axis, keepdims=keepdims) + masked_x = tf.where(mask, tf.zeros_like(x), x) + valid_sum = tf.reduce_sum(masked_x, axis=axis, keepdims=keepdims) + + ret = valid_sum / valid_counts + ret = tf.where(tf.math.is_nan(ret), tf.zeros_like(ret), ret) + + return ret + + +@numeric.masked_median.register(tf.Tensor) +def _( + x: tf.Tensor, mask: Optional[tf.Tensor], axis: Union[int, Tuple[int, ...], List[int]], keepdims=False +) -> tf.Tensor: + if mask is None: + return numeric.median(x, axis=axis, keepdims=keepdims) + + masked_x = tf.where(mask, np.nan, x) + np_masked_x = masked_x.numpy() + np_masked_median = np.nanquantile(np_masked_x, 0.5, axis=axis, keepdims=keepdims) + + with tf.device(x.device): + ret = tf.constant(np_masked_median) + ret = tf.where(tf.math.is_nan(ret), tf.zeros_like(ret), ret) + + return ret + + +@numeric.expand_dims.register(tf.Tensor) +def _(a: tf.Tensor, axis: Union[int, Tuple[int, ...], List[int]]) -> np.ndarray: + if type(axis) not in (tuple, list): + axis = (axis,) + + if len(set(axis)) != len(axis): + raise ValueError("repeated axis") + + out_ndim = len(axis) + a.ndim + + norm_axis = [] + for ax in axis: + if ax < -out_ndim or ax >= out_ndim: + raise ValueError(f"axis {ax} is out of bounds for array of dimension {out_ndim}") + norm_axis.append(ax + out_ndim if ax < 0 else ax) + + shape_it = iter(a.shape) + shape = [1 if ax in norm_axis else next(shape_it) for ax in range(out_ndim)] + return tf.reshape(a, shape) + + +@numeric.clone.register(tf.Tensor) +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + return tf.identity(a) + + +@numeric.searchsorted.register(tf.Tensor) +def _(a: tf.Tensor, v: tf.Tensor, side: str = "left", sorter: Optional[tf.Tensor] = None) -> tf.Tensor: + if side not in ["right", "left"]: + raise ValueError(f"Invalid value for 'side': {side}. Expected 'right' or 'left'.") + if a.ndim != 1: + raise ValueError(f"Input tensor 'a' must be 1-D. Received {a.ndim}-D tensor.") + sorted_a = tf.sort(a) + return tf.searchsorted(sorted_sequence=sorted_a, values=v, side=side) + + +def zeros( + shape: Tuple[int, ...], + *, + dtype: Optional[TensorDataType] = None, + device: Optional[TensorDeviceType] = None, +) -> tf.Tensor: + if dtype is not None: + dtype = DTYPE_MAP[dtype] + if device is not None: + device = DEVICE_MAP[device] + with tf.device(device): + zeros = tf.zeros(shape, dtype=dtype) + return zeros + + +def eye( + n: int, + m: Optional[int] = None, + *, + dtype: Optional[TensorDataType] = None, + device: Optional[TensorDeviceType] = None, +) -> tf.Tensor: + if dtype is not None: + dtype = DTYPE_MAP[dtype] + if device is not None: + device = DEVICE_MAP[device] + p_args = (n,) if m is None else (n, m) + with tf.device(device): + return tf.eye(*p_args, dtype=dtype) + + +def arange( + start: float, + end: float, + step: float, + *, + dtype: Optional[TensorDataType] = None, + device: Optional[TensorDeviceType] = None, +) -> tf.Tensor: + if dtype is not None: + dtype = DTYPE_MAP[dtype] + if device is not None: + device = DEVICE_MAP[device] + with tf.device(device): + r = tf.range(start, end, step, dtype=dtype) + return r + + +def from_numpy(ndarray: np.ndarray) -> tf.Tensor: + with tf.device("CPU"): + return tf.constant(ndarray) + + +@numeric.log2.register(tf.Tensor) +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + return tf.math.log(a) / tf.math.log(2.0) + + +@numeric.ceil.register(tf.Tensor) +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + return tf.math.ceil(a) diff --git a/nncf/tensor/tensor.py b/nncf/tensor/tensor.py index 9ccd841b164..998b3dc2dfb 100644 --- a/nncf/tensor/tensor.py +++ b/nncf/tensor/tensor.py @@ -140,7 +140,7 @@ def __floordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: def __rfloordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: return cast(Tensor, _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv)) - def __ifloordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: + def __ifloordiv__(self, other: Union[Tensor, float]) -> Tensor: self._data //= unwrap_tensor_data(other) return self diff --git a/nncf/version.py b/nncf/version.py index f7b5b2206e3..26ec55fa1da 100644 --- a/nncf/version.py +++ b/nncf/version.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.16.0" +__version__ = "2.16.0.dev0+b1af5d11dirty" BKC_TORCH_SPEC = "==2.6.*" diff --git a/tests/cross_fw/test_templates/template_test_nncf_tensor.py b/tests/cross_fw/test_templates/template_test_nncf_tensor.py index 8a2f54a03af..7bf600ada7a 100644 --- a/tests/cross_fw/test_templates/template_test_nncf_tensor.py +++ b/tests/cross_fw/test_templates/template_test_nncf_tensor.py @@ -113,7 +113,8 @@ def test_operators_tensor(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) - assert res_nncf.device == nncf_tensor_a.device + if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU): + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", OPERATOR_MAP.keys()) def test_operators_int(self, op_name): @@ -129,7 +130,8 @@ def test_operators_int(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) - assert res_nncf.device == nncf_tensor_a.device + if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU): + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", BINARY_OPERATORS) def test_operators_int_rev(self, op_name): @@ -145,7 +147,11 @@ def test_operators_int_rev(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) - assert res_nncf.device == nncf_tensor_a.device + if not ( + (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU) + or (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.GPU and op_name == "pow") + ): + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", COMPARISON_OPERATOR_MAP.keys()) def test_comparison_tensor(self, op_name): @@ -159,7 +165,7 @@ def test_comparison_tensor(self, op_name): res = fn(tensor_a, tensor_b) res_nncf = fn(nncf_tensor_a, nncf_tensor_b) - assert res == res_nncf + assert res_nncf == res assert isinstance(res_nncf, Tensor) @pytest.mark.parametrize("op_name", COMPARISON_OPERATOR_MAP.keys()) @@ -173,7 +179,7 @@ def test_comparison_int(self, op_name): res = fn(tensor_a, value) res_nncf = fn(nncf_tensor_a, value) - assert res == res_nncf + assert res_nncf == res assert isinstance(res_nncf, Tensor) @pytest.mark.parametrize("op_name", COMPARISON_OPERATOR_MAP.keys()) @@ -187,7 +193,7 @@ def test_comparison_int_rev(self, op_name): res = fn(value, tensor_a) res_nncf = fn(value, nncf_tensor_a) - assert res == res_nncf + assert res_nncf == res assert isinstance(res_nncf, Tensor) @pytest.mark.parametrize( @@ -390,7 +396,8 @@ def test_getitem_for_index(self): res = nncf_tensor[1] assert res == 1 assert isinstance(res, Tensor) - assert res.device == nncf_tensor.device + if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU): + assert res.device == nncf_tensor.device @pytest.mark.parametrize("is_tensor_indecies", (False, True)) def test_getitem_for_indecies(self, is_tensor_indecies): @@ -527,7 +534,8 @@ def test_fn_where(self): res = fns.where(tensor > 0, 1, 0) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) - assert res.device == tensor.device + if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU): + assert res.device == tensor.device @pytest.mark.parametrize( "val, ref", @@ -1101,7 +1109,8 @@ def test_fn_matmul(self, m1, m2, ref): assert isinstance(res, Tensor) assert fns.allclose(res.data, ref_tensor) - assert res.device == tensor1.device + if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU): + assert res.device == tensor1.device @pytest.mark.parametrize( "val, axis, ref", @@ -1544,6 +1553,8 @@ def test_fn_eye(self, n, m, ref): ] ): continue + if (not dtype.is_float()) and self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.GPU: + continue tensor_a = fns.eye(n, m, backend=self.backend(), dtype=dtype, device=self.device()) assert isinstance(tensor_a, Tensor) assert tensor_a.device == self.device() @@ -1563,18 +1574,20 @@ def test_fn_arange(self, start, end, stop, ref): args.append(end) if stop is not None: args.append(stop) - ref = Tensor(self.to_tensor(ref)) + for dtype in [TensorDataType.int32, TensorDataType.float32]: + tensor_ref = Tensor(fns.astype(self.to_tensor(ref), dtype)) tensor_a = fns.arange(*tuple(args), backend=self.backend(), dtype=dtype, device=self.device()) assert isinstance(tensor_a, Tensor) assert tensor_a.device == self.device() assert tensor_a.backend == self.backend() assert tensor_a.dtype == dtype - assert fns.all(tensor_a == ref) + assert fns.all(tensor_a == tensor_ref) def test_fn_from_numpy(self): ndarray = np.array([1, 2]) - ref = Tensor(self.to_cpu(self.to_tensor(ndarray))) + ref_cpu = self.to_cpu(self.to_tensor(ndarray)) + ref = Tensor(ref_cpu) tensor = fns.from_numpy(ndarray, backend=ref.backend) assert isinstance(tensor, Tensor) assert tensor.device == ref.device diff --git a/tests/tensorflow/test_tensor.py b/tests/tensorflow/test_tensor.py new file mode 100644 index 00000000000..be0c83a1d03 --- /dev/null +++ b/tests/tensorflow/test_tensor.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import tensorflow as tf + +from nncf.tensor import Tensor +from nncf.tensor import TensorDataType +from nncf.tensor.definitions import TensorBackend +from nncf.tensor.definitions import TensorDeviceType +from nncf.tensor.functions import linalg +from tests.cross_fw.test_templates.template_test_nncf_tensor import TemplateTestNNCFTensorOperators + + +def cast_to(x: tf.Tensor, dtype: TensorDataType) -> tf.Tensor: + if dtype is TensorDataType.float32: + return tf.cast(x, tf.float32) + if dtype is TensorDataType.float16: + return tf.cast(x, tf.float16) + raise NotImplementedError + + +class TestTFNNCFTensorOperators(TemplateTestNNCFTensorOperators): + @staticmethod + def to_tensor(x): + with tf.device("/CPU:0"): + return tf.constant(x) + + @staticmethod + def to_cpu(x): + return x + + @staticmethod + def cast_to(x: tf.Tensor, dtype: TensorDataType) -> tf.Tensor: + return cast_to(x, dtype) + + @staticmethod + def backend() -> TensorBackend: + return TensorBackend.tf + + @staticmethod + def device() -> TensorDeviceType: + return TensorDeviceType.CPU + + def test_norm_keepdims(self): + tensor_data = [[1.0, 2.0], [3.0, 4.0]] + tf_tensor = self.to_tensor(tensor_data) + tensor = Tensor(tf_tensor) + + result = linalg.norm(tensor, ord="nuc", keepdims=True) + + assert result.shape == (1, 1) + + for ord_val in [None, 0, 1, 2, -1, -2, "fro"]: + result = linalg.norm(tensor, ord=ord_val, keepdims=True) + assert result.shape == (1, 1), f"Failed for ord={ord_val}" + + def test_lstsq_rank2(self): + x_data = [1.0, 2.0, 4.0] + ones_data = [1.0, 1.0, 1.0] + a_data = [[x_data[0], ones_data[0]], [x_data[1], ones_data[1]], [x_data[2], ones_data[2]]] + + a_tensor = self.to_tensor(a_data) + a = Tensor(a_tensor) + + b_data = [[6.0, 6.0], [8.0, 10.0], [12.0, 18.0]] + b_tensor = self.to_tensor(b_data) + b = Tensor(b_tensor) + + x = linalg.lstsq(a, b) + + assert x.shape == (2, 2) + + expected = [[2.0, 4.0], [4.0, 2.0]] + + for i in range(2): + for j in range(2): + x_val = x.data.numpy()[i][j] + expected_val = expected[i][j] + assert abs(x_val - expected_val) < 0.2, f"Value at ({i},{j}) is {x_val}, expected {expected_val}" + + @pytest.mark.skip("Desired slicing is not supported for TensorFlow") + @pytest.mark.parametrize("is_tensor_indecies", (False, True)) + def test_getitem_for_indecies(self, is_tensor_indecies): + pass + + @pytest.mark.skip("TensorFlow throws different kind of exceptions") + @pytest.mark.parametrize( + "val, axis, exception_type, exception_match", + ( + ([[[[1], [2]], [[1], [2]]]], (0, 1), ValueError, "not equal to one"), + ([[[[1], [2]], [[1], [2]]]], 42, IndexError, "out of"), + ([[[[1], [2]], [[1], [2]]]], (0, 42), IndexError, "out of"), + ), + ) + def test_squeeze_axis_error(self, val, axis, exception_type, exception_match): + pass + + +@pytest.mark.skipif(len(tf.config.list_physical_devices("GPU")) == 0, reason="Skipping for CPU-only setups") +class TestGPUTFNNCFTensorOperators(TemplateTestNNCFTensorOperators): + @staticmethod + def to_tensor(x): + with tf.device("GPU"): + return tf.constant(x) + + @staticmethod + def to_cpu(x): + with tf.device("CPU"): + return tf.constant(x.numpy()) + + @staticmethod + def cast_to(x: tf.Tensor, dtype: TensorDataType) -> tf.Tensor: + return cast_to(x, dtype) + + def test_device(self): + tensor = Tensor(self.to_tensor([1])) + assert tensor.device == TensorDeviceType.GPU + + def test_norm_keepdims(self): + tensor_data = [[1.0, 2.0], [3.0, 4.0]] + tf_tensor = self.to_tensor(tensor_data) + tensor = Tensor(tf_tensor) + + result = linalg.norm(tensor, ord="nuc", keepdims=True) + + assert result.shape == (1, 1) + + for ord_val in [None, 0, 1, 2, -1, -2, "fro"]: + result = linalg.norm(tensor, ord=ord_val, keepdims=True) + assert result.shape == (1, 1), f"Failed for ord={ord_val}" + + def test_lstsq_rank2(self): + x_data = [1.0, 2.0, 4.0] + ones_data = [1.0, 1.0, 1.0] + a_data = [[x_data[0], ones_data[0]], [x_data[1], ones_data[1]], [x_data[2], ones_data[2]]] + + a_tensor = self.to_tensor(a_data) + a = Tensor(a_tensor) + + b_data = [[6.0, 6.0], [8.0, 10.0], [12.0, 18.0]] + b_tensor = self.to_tensor(b_data) + b = Tensor(b_tensor) + + x = linalg.lstsq(a, b) + + assert x.shape == (2, 2) + + expected = [[2.0, 4.0], [4.0, 2.0]] + + for i in range(2): + for j in range(2): + x_val = x.data.numpy()[i][j] + expected_val = expected[i][j] + assert abs(x_val - expected_val) < 0.2, f"Value at ({i},{j}) is {x_val}, expected {expected_val}" + + @staticmethod + def backend() -> TensorBackend: + return TensorBackend.tf + + @staticmethod + def device() -> TensorDeviceType: + return TensorDeviceType.GPU + + @pytest.mark.skip("Desired slicing is not supported for TensorFlow") + @pytest.mark.parametrize("is_tensor_indecies", (False, True)) + def test_getitem_for_indecies(self, is_tensor_indecies): + pass + + @pytest.mark.skip("TensorFlow throws different kind of exceptions") + @pytest.mark.parametrize( + "val, axis, exception_type, exception_match", + ( + ([[[[1], [2]], [[1], [2]]]], (0, 1), ValueError, "not equal to one"), + ([[[[1], [2]], [[1], [2]]]], 42, IndexError, "out of"), + ([[[[1], [2]], [[1], [2]]]], (0, 42), IndexError, "out of"), + ), + ) + def test_squeeze_axis_error(self, val, axis, exception_type, exception_match): + pass From 70fa4c5e51094faab82cc6c2e753f6c4922f9699 Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Tue, 10 Dec 2024 17:10:56 -0800 Subject: [PATCH 02/17] Address comments --- nncf/tensor/functions/tf_numeric.py | 29 ++++++++++------------------- tests/tensorflow/test_tensor.py | 2 +- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index 4e3afd29c53..131a0390516 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -39,7 +39,10 @@ @numeric.device.register(tf.Tensor) def _(a: tf.Tensor) -> TensorDeviceType: - return DEVICE_MAP_REV[a.device.split("/")[-1].split(":")[1]] + if "CPU" in a.device: + return DEVICE_MAP_REV["CPU"] + if "GPU" in a.device: + return DEVICE_MAP_REV["GPU"] @numeric.backend.register(tf.Tensor) @@ -136,7 +139,7 @@ def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Te @numeric.isempty.register(tf.Tensor) def _(a: tf.Tensor) -> bool: - return bool(tf.equal(tf.size(a), 0).numpy().T) + return bool(tf.equal(tf.size(a), 0).numpy()) @numeric.isclose.register(tf.Tensor) @@ -199,18 +202,8 @@ def _(x: tf.Tensor, axis: int = 0) -> List[tf.Tensor]: @numeric.moveaxis.register(tf.Tensor) def _(a: tf.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> tf.Tensor: - perm = list(range(a._rank())) - if isinstance(source, int): - axe_to_move = perm.pop(source) - if destination < 0: - destination = len(perm) + destination + 1 - perm.insert(destination, axe_to_move) - else: - old_perm = perm[:] - for i in range(len(source)): - perm[destination[i]] = old_perm[source[i]] with tf.device(a.device): - return tf.transpose(a, perm) + return tf.experimental.numpy.moveaxis(a, source, destination) @numeric.mean.register(tf.Tensor) @@ -311,6 +304,7 @@ def _(a: tf.Tensor, data: Any) -> tf.Tensor: @numeric.item.register(tf.Tensor) def _(a: tf.Tensor) -> Union[int, float, bool]: + a = tf.reshape(a, []) np_item = a.numpy() if isinstance(np_item, np.floating): return float(np_item) @@ -337,11 +331,10 @@ def _( a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ddof: int = 0 ) -> tf.Tensor: with tf.device(a.device): - assert ddof in {0, 1} tf_var = tf.math.reduce_variance(a, axis=axis, keepdims=keepdims) if ddof: n = tf.shape(a)[axis] if axis is not None else tf.size(a) - tf_var *= float(n) / float(n - 1) + tf_var *= float(n) / float(n - ddof) return tf_var @@ -480,8 +473,7 @@ def zeros( if device is not None: device = DEVICE_MAP[device] with tf.device(device): - zeros = tf.zeros(shape, dtype=dtype) - return zeros + return tf.zeros(shape, dtype=dtype) def eye( @@ -513,8 +505,7 @@ def arange( if device is not None: device = DEVICE_MAP[device] with tf.device(device): - r = tf.range(start, end, step, dtype=dtype) - return r + return tf.range(start, end, step, dtype=dtype) def from_numpy(ndarray: np.ndarray) -> tf.Tensor: diff --git a/tests/tensorflow/test_tensor.py b/tests/tensorflow/test_tensor.py index be0c83a1d03..7169d31c0b6 100644 --- a/tests/tensorflow/test_tensor.py +++ b/tests/tensorflow/test_tensor.py @@ -30,7 +30,7 @@ def cast_to(x: tf.Tensor, dtype: TensorDataType) -> tf.Tensor: class TestTFNNCFTensorOperators(TemplateTestNNCFTensorOperators): @staticmethod def to_tensor(x): - with tf.device("/CPU:0"): + with tf.device("CPU"): return tf.constant(x) @staticmethod From 2cdd5eaa4b22107452f447f5864477c2a4b95f49 Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Wed, 18 Dec 2024 18:39:27 -0800 Subject: [PATCH 03/17] Address part of comments --- nncf/tensor/functions/tf_linalg.py | 51 +++++++++++++++---- nncf/tensor/functions/tf_numeric.py | 28 +++------- .../template_test_nncf_tensor.py | 51 +++++++++++++++++++ 3 files changed, 100 insertions(+), 30 deletions(-) diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py index a13e8f48b05..35b5a472d31 100644 --- a/nncf/tensor/functions/tf_linalg.py +++ b/nncf/tensor/functions/tf_linalg.py @@ -24,17 +24,18 @@ def _( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> tf.Tensor: - if axis is None: - axis = 0 if a._rank() == 1 else (0, 1) - - if ord is None or (a._rank() == 1 and ord == "fro"): + if ord is None: ord = "euclidean" + rank = tf.rank(a) + if rank == 2 and axis is None: + axis = (0, 1) with tf.device(a.device): - if ord == "nuc": - s, _, _ = tf.linalg.svd(a) - result = tf.reduce_sum(s) - + if ord == "nuc" and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord='nuc' is only supported for 2D tensors") + s = tf.linalg.svd(a, compute_uv=False) + result = tf.reduce_sum(s, axis=-1) if keepdims: result_shape = [1 if i in axis else dim for i, dim in enumerate(a.shape)] result = tf.reshape(result, result_shape) @@ -43,6 +44,38 @@ def _( if ord == 0: return tf.cast(tf.math.count_nonzero(a, axis=axis, keepdims=keepdims), a.dtype) + if ord == -1 and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=-1 is only supported for 2D tensors") + return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) + + if ord == 1 and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=1 is only supported for 2D tensors") + return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) + + if ord == -2 and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=-2 is only supported for 2D tensors") + s = tf.linalg.svd(a, compute_uv=False) + return tf.reduce_min(s, axis=-1) + + if ord == 2 and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=2 is only supported for 2D tensors") + s = tf.linalg.svd(a, compute_uv=False) + return tf.reduce_max(s, axis=-1) + + if ord == float("inf") and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=inf is only supported for 2D tensors") + return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) + + if ord == -float("inf") and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=-inf is only supported for 2D tensors") + return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) + return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims) @@ -97,4 +130,4 @@ def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> tf.Tensor: with tf.device(a.device): s, u, v = tf.linalg.svd(a, full_matrices=full_matrices) - return u, s, tf.transpose(v) + return u, s, tf.transpose(v, conjugate=True) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index 131a0390516..ed436556d7f 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -53,10 +53,6 @@ def _(a: tf.Tensor) -> TensorBackend: @numeric.squeeze.register(tf.Tensor) def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: with tf.device(a.device): - if axis is None: - return tf.squeeze(a) - if isinstance(axis, Tuple) and any(a.shape[i] != 1 for i in axis): - raise ValueError("Cannot select an axis to squeeze out which has size not equal to one") return tf.squeeze(a, axis) @@ -67,19 +63,15 @@ def _(a: tf.Tensor) -> tf.Tensor: @numeric.max.register(tf.Tensor) -def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> tf.Tensor: +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor: with tf.device(a.device): - if axis is None: - return tf.reduce_max(a) - return tf.reduce_max(a, axis=axis, keepdims=keepdim) + return tf.reduce_max(a, axis=axis, keepdims=keepdims) @numeric.min.register(tf.Tensor) -def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> tf.Tensor: +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor: with tf.device(a.device): - if axis is None: - return tf.reduce_min(a) - return tf.reduce_min(a, axis=axis, keepdims=keepdim) + return tf.reduce_min(a, axis=axis, keepdims=keepdims) @numeric.abs.register(tf.Tensor) @@ -139,7 +131,7 @@ def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Te @numeric.isempty.register(tf.Tensor) def _(a: tf.Tensor) -> bool: - return bool(tf.equal(tf.size(a), 0).numpy()) + return bool(tf.equal(tf.size(a), 0)) @numeric.isclose.register(tf.Tensor) @@ -214,6 +206,7 @@ def _( dtype: Optional[TensorDataType] = None, ) -> tf.Tensor: with tf.device(a.device): + a = tf.cast(a, DTYPE_MAP[dtype]) if dtype is not None else a return tf.reduce_mean(a, axis=axis, keepdims=keepdims) @@ -304,14 +297,7 @@ def _(a: tf.Tensor, data: Any) -> tf.Tensor: @numeric.item.register(tf.Tensor) def _(a: tf.Tensor) -> Union[int, float, bool]: - a = tf.reshape(a, []) - np_item = a.numpy() - if isinstance(np_item, np.floating): - return float(np_item) - if isinstance(np_item, np.bool_): - return bool(np_item) - - return int(np_item) + return a.numpy().item() @numeric.sum.register(tf.Tensor) diff --git a/tests/cross_fw/test_templates/template_test_nncf_tensor.py b/tests/cross_fw/test_templates/template_test_nncf_tensor.py index 7bf600ada7a..1e68a329305 100644 --- a/tests/cross_fw/test_templates/template_test_nncf_tensor.py +++ b/tests/cross_fw/test_templates/template_test_nncf_tensor.py @@ -812,6 +812,7 @@ def test_fn_median(self, x, axis, keepdims, ref): (1.1, 0, 1.0), ([1.1, 0.9], 0, [1.0, 1.0]), ([1.11, 0.91], 1, [1.1, 0.9]), + ([5.5, 3.3], -1, [10.0, 0.0]), ), ) def test_fn_round(self, val, decimals, ref): @@ -1053,6 +1054,13 @@ def test_fn_var(self, x, axis, keepdims, ddof, ref): True, [[1.53063197]], ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + "nuc", + (0, 1), + False, + [1.53063197], + ), ( [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], float("inf"), @@ -1067,6 +1075,49 @@ def test_fn_var(self, x, axis, keepdims, ddof, ref): False, 0.9364634205074938, ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 2, + 0, + False, + [0.8062258, 0.72801095, 0.22360681], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 1, + None, + False, + 0.9, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + -1, + None, + False, + 0.3, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + -2, + None, + False, + 0.59416854, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + float("inf"), + None, + False, + 1.2, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + -float("inf"), + None, + False, + 0.9, + ), + ([[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], None, None, False, 2.82842708), ), ) def test_fn_linalg_norm(self, x, ord, axis, keepdims, ref): From 6a57f31ae501e93ffaa6e0ebe529c471c1020155 Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Thu, 19 Dec 2024 20:02:27 -0800 Subject: [PATCH 04/17] Add updates from develop --- docs/api/source/conf.py | 1 + nncf/tensor/functions/__init__.py | 1 + nncf/tensor/functions/dispatcher.py | 4 ++ nncf/tensor/functions/tf_io.py | 28 ++++++++++++++ nncf/tensor/functions/tf_numeric.py | 38 ++++++++++++++----- .../template_test_nncf_tensor.py | 7 ++++ 6 files changed, 69 insertions(+), 10 deletions(-) create mode 100644 nncf/tensor/functions/tf_io.py diff --git a/docs/api/source/conf.py b/docs/api/source/conf.py index e008c1f704d..b7133eddf0a 100644 --- a/docs/api/source/conf.py +++ b/docs/api/source/conf.py @@ -146,6 +146,7 @@ def collect_api_entities() -> APIInfo: "nncf.tensor.functions.torch_io", "nncf.tensor.functions.numpy_io", "nncf.tensor.functions.tf_numeric", + "nncf.tensor.functions.tf_io", "nncf.tensor.functions.tf_linalg", ] diff --git a/nncf/tensor/functions/__init__.py b/nncf/tensor/functions/__init__.py index c5d21dbc5d2..8af96e3ed66 100644 --- a/nncf/tensor/functions/__init__.py +++ b/nncf/tensor/functions/__init__.py @@ -75,6 +75,7 @@ def _initialize_backends() -> None: import nncf.tensor.functions.numpy_numeric with contextlib.suppress(ImportError): + import nncf.tensor.functions.tf_io import nncf.tensor.functions.tf_linalg import nncf.tensor.functions.tf_numeric diff --git a/nncf/tensor/functions/dispatcher.py b/nncf/tensor/functions/dispatcher.py index c97574cc596..c7ce62448f0 100644 --- a/nncf/tensor/functions/dispatcher.py +++ b/nncf/tensor/functions/dispatcher.py @@ -305,6 +305,10 @@ def get_io_backend_fn(fn_name: str, backend: TensorBackend) -> Callable[..., Any from nncf.tensor.functions import numpy_io return getattr(numpy_io, fn_name) + if backend == TensorBackend.tf: + from nncf.tensor.functions import tf_io + + return getattr(tf_io, fn_name) if backend == TensorBackend.torch: from nncf.tensor.functions import torch_io diff --git a/nncf/tensor/functions/tf_io.py b/nncf/tensor/functions/tf_io.py new file mode 100644 index 00000000000..bf97ab43121 --- /dev/null +++ b/nncf/tensor/functions/tf_io.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import tensorflow as tf +from safetensors.tensorflow import load_file as tf_load_file +from safetensors.tensorflow import save_file as tf_save_file + +from nncf.tensor import TensorDeviceType +from nncf.tensor.functions import io as io + + +def load_file(file_path: str, *, device: Optional[TensorDeviceType] = None) -> Dict[str, tf.Tensor]: + return tf_load_file(file_path) + + +@io.save_file.register(tf.Tensor) +def _(data: Dict[str, tf.Tensor], file_path: str) -> None: + return tf_save_file(data, file_path) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index ed436556d7f..c8874a066f3 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import numpy as np import tensorflow as tf @@ -19,6 +19,7 @@ from nncf.tensor.definitions import TensorBackend from nncf.tensor.definitions import TypeInfo from nncf.tensor.functions import numeric as numeric +from nncf.tensor.tensor import TTensor DTYPE_MAP = { TensorDataType.float16: tf.float16, @@ -37,6 +38,14 @@ DEVICE_MAP_REV = {v: k for k, v in DEVICE_MAP.items()} +def convert_to_tf_device(device: TensorDeviceType) -> str: + return DEVICE_MAP[device] if device is not None else None + + +def convert_to_tf_dtype(dtype: TensorDataType) -> tf.DType: + return DTYPE_MAP[dtype] if dtype is not None else None + + @numeric.device.register(tf.Tensor) def _(a: tf.Tensor) -> TensorDeviceType: if "CPU" in a.device: @@ -357,16 +366,13 @@ def _(a: tf.Tensor, axis: int = -1, descending=False, stable=False) -> tf.Tensor @numeric.diag.register(tf.Tensor) def _(a: tf.Tensor, k: int = 0) -> tf.Tensor: with tf.device(a.device): - if a._rank() == 2: - if k == 0: - return tf.linalg.diag_part(a) - elif k > 0: - return tf.linalg.diag_part(a[:, k:]) - else: - return tf.linalg.diag_part(a[-k:, :]) - - if a._rank() == 1: + rank = tf.rank(a) + if rank == 1: return tf.linalg.diag(a, k=k) + elif rank == 2: + return tf.linalg.diag_part(a, k=k) + else: + raise ValueError("Input tensor must be 1D or 2D.") @numeric.logical_or.register(tf.Tensor) @@ -509,3 +515,15 @@ def _(a: tf.Tensor) -> tf.Tensor: def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.math.ceil(a) + + +def tensor( + data: Union[TTensor, Sequence[float]], + *, + dtype: Optional[TensorDataType] = None, + device: Optional[TensorDeviceType] = None, +) -> tf.Tensor: + device = convert_to_tf_device(device) + dtype = convert_to_tf_dtype(dtype) + with tf.device(device): + return tf.constant(data, dtype=dtype) diff --git a/tests/cross_fw/test_templates/template_test_nncf_tensor.py b/tests/cross_fw/test_templates/template_test_nncf_tensor.py index 1e68a329305..e6d0e79b540 100644 --- a/tests/cross_fw/test_templates/template_test_nncf_tensor.py +++ b/tests/cross_fw/test_templates/template_test_nncf_tensor.py @@ -1807,6 +1807,13 @@ def test_save_load_symlink_error(self, tmp_path): @pytest.mark.parametrize("data", [[3.0, 2.0, 2.0], [1, 2, 3]]) @pytest.mark.parametrize("dtype", [TensorDataType.float32, TensorDataType.int32, TensorDataType.uint8, None]) def test_fn_tensor(self, data, dtype): + if ( + self.backend() == TensorBackend.tf + and dtype is not None + and not dtype.is_float() + and (data == [3.0, 2.0, 2.0]) + ): + pytest.skip("TF backend does not support non-float dtypes for float data") nncf_tensor = fns.tensor(data, backend=self.backend(), dtype=dtype, device=self.device()) backend_tensor = Tensor(self.to_tensor(data)) if dtype is not None: From 20fe39cb7f0cc0d64a6ed874b6442dd98f7bd8c0 Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Sun, 22 Dec 2024 16:32:14 -0800 Subject: [PATCH 05/17] Address comments --- nncf/tensor/functions/tf_numeric.py | 34 ++++++++++++++----- .../template_test_nncf_tensor.py | 22 +++++++----- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index c8874a066f3..e1a61437ed0 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -119,9 +119,7 @@ def _( a: tf.Tensor, b: Union[tf.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> bool: with tf.device(a.device): - if not isinstance(b, tf.Tensor): - b = tf.constant(b) - return tf.experimental.numpy.allclose(a, tf.cast(b, a.dtype), rtol=rtol, atol=atol, equal_nan=equal_nan) + return bool(tf.experimental.numpy.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) @numeric.any.register(tf.Tensor) @@ -148,7 +146,7 @@ def _( a: tf.Tensor, b: Union[tf.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> tf.Tensor: with tf.device(a.device): - return tf.experimental.numpy.isclose(a, tf.cast(b, a.dtype), atol=atol, rtol=rtol, equal_nan=equal_nan) + return tf.experimental.numpy.isclose(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan) @numeric.maximum.register(tf.Tensor) @@ -225,13 +223,31 @@ def _( axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False, ) -> tf.Tensor: - numpy_a = np.array(a) - numpy_median = np.median(numpy_a, axis=axis, keepdims=keepdims) - with tf.device(a.device): - tf_median = tf.constant(numpy_median) + if axis is None: + a = tf.reshape(a, [-1]) + else: + if isinstance(axis, int): + axis = (axis,) + destination_axis = tuple([-(i + 1) for i in range(len(axis))]) + a = tf.experimental.numpy.moveaxis(a, axis, destination_axis) + last_axis = 1 + for i in range(len(axis)): + last_axis *= a.shape[-(i + 1)] + new_shape = a.shape[: -len(axis)] + [last_axis] + a = tf.reshape(a, new_shape) + k = 1 + a.shape[-1] // 2 + top_k = tf.math.top_k(a, k=k, sorted=True).values + if a.shape[-1] % 2 == 0: + median = (tf.gather(top_k, indices=[k - 2], axis=-1) + tf.gather(top_k, indices=[k - 1], axis=-1)) / 2 + else: + median = tf.gather(top_k, indices=[k - 1], axis=-1) + median = tf.squeeze(median, axis=-1) + if keepdims: + for axe in sorted(axis, key=lambda x: abs(x)): + median = tf.expand_dims(median, axe) - return tf_median + return median @numeric.round.register(tf.Tensor) diff --git a/tests/cross_fw/test_templates/template_test_nncf_tensor.py b/tests/cross_fw/test_templates/template_test_nncf_tensor.py index e6d0e79b540..fca455f3b18 100644 --- a/tests/cross_fw/test_templates/template_test_nncf_tensor.py +++ b/tests/cross_fw/test_templates/template_test_nncf_tensor.py @@ -566,19 +566,23 @@ def test_isempty(self, val, ref): assert isinstance(res, bool) @pytest.mark.parametrize( - "x1, x2, rtol, atol, ref", + "x1, x2, is_tensor, rtol, atol, ref", ( - ([0.1], [0.1], None, None, True), - ([0.1], [0.10001], None, None, False), - ([0.1], [0.10001], 0.1, None, True), - ([0.1], [0.10001], None, 0.1, True), - ([0.1], [0.20001], None, 0.1, False), - ([0.1], 0.1, None, None, True), + ([0.1], [0.1], True, None, None, True), + ([0.1], [0.10001], True, None, None, False), + ([0.1], [0.10001], True, 0.1, None, True), + ([0.1], [0.10001], True, None, 0.1, True), + ([0.1], [0.20001], True, None, 0.1, False), + ([0.1], 0.1, True, None, None, True), + ([0.1], 0.1, False, None, None, True), ), ) - def test_fn_allclose(self, x1, x2, rtol, atol, ref): + def test_fn_allclose(self, x1, x2, is_tensor, rtol, atol, ref): tensor1 = Tensor(self.to_tensor(x1)) - tensor2 = Tensor(self.to_tensor(x2)) + if is_tensor: + tensor2 = Tensor(self.to_tensor(x2)) + else: + tensor2 = x2 if rtol is not None: res = fns.allclose(tensor1, tensor2, rtol=rtol) elif atol is not None: From fb347ab112d6af8cd923fb8b6e1cdeaef71da3b6 Mon Sep 17 00:00:00 2001 From: darshil929 Date: Sat, 15 Mar 2025 00:54:08 +0530 Subject: [PATCH 06/17] fix tf_io.py save_file registration decorator --- nncf/tensor/functions/tf_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nncf/tensor/functions/tf_io.py b/nncf/tensor/functions/tf_io.py index bf97ab43121..6ab5d7e4b0f 100644 --- a/nncf/tensor/functions/tf_io.py +++ b/nncf/tensor/functions/tf_io.py @@ -23,6 +23,6 @@ def load_file(file_path: str, *, device: Optional[TensorDeviceType] = None) -> D return tf_load_file(file_path) -@io.save_file.register(tf.Tensor) +@io.save_file.register def _(data: Dict[str, tf.Tensor], file_path: str) -> None: return tf_save_file(data, file_path) From 259e12b2f2d56a8281792f1ec6e5a416ab8429c0 Mon Sep 17 00:00:00 2001 From: darshil929 Date: Sat, 15 Mar 2025 18:23:48 +0530 Subject: [PATCH 07/17] fix exception string literals in tf_linalg.py --- nncf/tensor/functions/tf_linalg.py | 23 +++++++++++++++-------- tests/tensorflow/test_tensor.py | 1 + 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py index 35b5a472d31..1b100b9f585 100644 --- a/nncf/tensor/functions/tf_linalg.py +++ b/nncf/tensor/functions/tf_linalg.py @@ -33,7 +33,8 @@ def _( with tf.device(a.device): if ord == "nuc" and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: - raise ValueError("ord='nuc' is only supported for 2D tensors") + error_msg = "ord='nuc' is only supported for 2D tensors" + raise ValueError(error_msg) s = tf.linalg.svd(a, compute_uv=False) result = tf.reduce_sum(s, axis=-1) if keepdims: @@ -46,34 +47,40 @@ def _( if ord == -1 and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: - raise ValueError("ord=-1 is only supported for 2D tensors") + error_msg = "ord=-1 is only supported for 2D tensors" + raise ValueError(error_msg) return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) if ord == 1 and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: - raise ValueError("ord=1 is only supported for 2D tensors") + error_msg = "ord=1 is only supported for 2D tensors" + raise ValueError(error_msg) return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) if ord == -2 and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: - raise ValueError("ord=-2 is only supported for 2D tensors") + error_msg = "ord=-2 is only supported for 2D tensors" + raise ValueError(error_msg) s = tf.linalg.svd(a, compute_uv=False) return tf.reduce_min(s, axis=-1) if ord == 2 and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: - raise ValueError("ord=2 is only supported for 2D tensors") + error_msg = "ord=2 is only supported for 2D tensors" + raise ValueError(error_msg) s = tf.linalg.svd(a, compute_uv=False) return tf.reduce_max(s, axis=-1) if ord == float("inf") and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: - raise ValueError("ord=inf is only supported for 2D tensors") + error_msg = "ord=inf is only supported for 2D tensors" + raise ValueError(error_msg) return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) if ord == -float("inf") and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: - raise ValueError("ord=-inf is only supported for 2D tensors") + error_msg = "ord=-inf is only supported for 2D tensors" + raise ValueError(error_msg) return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims) @@ -130,4 +137,4 @@ def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> tf.Tensor: with tf.device(a.device): s, u, v = tf.linalg.svd(a, full_matrices=full_matrices) - return u, s, tf.transpose(v, conjugate=True) + return u, s, tf.transpose(v, conjugate=True) \ No newline at end of file diff --git a/tests/tensorflow/test_tensor.py b/tests/tensorflow/test_tensor.py index 7169d31c0b6..e5596201e0d 100644 --- a/tests/tensorflow/test_tensor.py +++ b/tests/tensorflow/test_tensor.py @@ -8,6 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import pytest import tensorflow as tf From 22a95809c933531ac7d27377eb96f964be6abc3f Mon Sep 17 00:00:00 2001 From: darshil929 Date: Sat, 15 Mar 2025 18:29:16 +0530 Subject: [PATCH 08/17] fix exception string literals in tf_numeric.py --- nncf/tensor/functions/tf_numeric.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index e1a61437ed0..ee469d1a8d8 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -388,7 +388,8 @@ def _(a: tf.Tensor, k: int = 0) -> tf.Tensor: elif rank == 2: return tf.linalg.diag_part(a, k=k) else: - raise ValueError("Input tensor must be 1D or 2D.") + error_msg = "Input tensor must be 1D or 2D." + raise ValueError(error_msg) @numeric.logical_or.register(tf.Tensor) @@ -439,14 +440,16 @@ def _(a: tf.Tensor, axis: Union[int, Tuple[int, ...], List[int]]) -> np.ndarray: axis = (axis,) if len(set(axis)) != len(axis): - raise ValueError("repeated axis") + error_msg = "repeated axis" + raise ValueError(error_msg) out_ndim = len(axis) + a.ndim norm_axis = [] for ax in axis: if ax < -out_ndim or ax >= out_ndim: - raise ValueError(f"axis {ax} is out of bounds for array of dimension {out_ndim}") + error_msg = f"axis {ax} is out of bounds for array of dimension {out_ndim}" + raise ValueError(error_msg) norm_axis.append(ax + out_ndim if ax < 0 else ax) shape_it = iter(a.shape) @@ -463,9 +466,11 @@ def _(a: tf.Tensor) -> tf.Tensor: @numeric.searchsorted.register(tf.Tensor) def _(a: tf.Tensor, v: tf.Tensor, side: str = "left", sorter: Optional[tf.Tensor] = None) -> tf.Tensor: if side not in ["right", "left"]: - raise ValueError(f"Invalid value for 'side': {side}. Expected 'right' or 'left'.") + error_msg = f"Invalid value for 'side': {side}. Expected 'right' or 'left'." + raise ValueError(error_msg) if a.ndim != 1: - raise ValueError(f"Input tensor 'a' must be 1-D. Received {a.ndim}-D tensor.") + error_msg = f"Input tensor 'a' must be 1-D. Received {a.ndim}-D tensor." + raise ValueError(error_msg) sorted_a = tf.sort(a) return tf.searchsorted(sorted_sequence=sorted_a, values=v, side=side) @@ -542,4 +547,4 @@ def tensor( device = convert_to_tf_device(device) dtype = convert_to_tf_dtype(dtype) with tf.device(device): - return tf.constant(data, dtype=dtype) + return tf.constant(data, dtype=dtype) \ No newline at end of file From 8d3b0da5577f9bbe6a6ee972d6ad79d6a0042886 Mon Sep 17 00:00:00 2001 From: darshil929 Date: Tue, 18 Mar 2025 00:41:34 +0530 Subject: [PATCH 09/17] address reviewer feedback: update file path types to Path, rollback version changes, and improve tensor device handling --- nncf/tensor/functions/tf_io.py | 5 +++-- nncf/tensor/functions/tf_numeric.py | 4 ++-- nncf/tensor/tensor.py | 2 +- nncf/version.py | 2 +- tests/cross_fw/test_templates/template_test_nncf_tensor.py | 3 +-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nncf/tensor/functions/tf_io.py b/nncf/tensor/functions/tf_io.py index 6ab5d7e4b0f..cf1a5c24f2b 100644 --- a/nncf/tensor/functions/tf_io.py +++ b/nncf/tensor/functions/tf_io.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path from typing import Dict, Optional import tensorflow as tf @@ -19,10 +20,10 @@ from nncf.tensor.functions import io as io -def load_file(file_path: str, *, device: Optional[TensorDeviceType] = None) -> Dict[str, tf.Tensor]: +def load_file(file_path: Path, *, device: Optional[TensorDeviceType] = None) -> Dict[str, tf.Tensor]: return tf_load_file(file_path) @io.save_file.register -def _(data: Dict[str, tf.Tensor], file_path: str) -> None: +def _(data: Dict[str, tf.Tensor], file_path: Path) -> None: return tf_save_file(data, file_path) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index ee469d1a8d8..28dcf80c18b 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -293,13 +293,13 @@ def _( @numeric._binary_op_nowarn.register(tf.Tensor) def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor: with tf.device(a.device): - return operator_fn(a, b) + return tf.identity(operator_fn(a, b)) @numeric._binary_reverse_op_nowarn.register(tf.Tensor) def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor: with tf.device(a.device): - return operator_fn(b, a) + return tf.identity(operator_fn(b, a)) @numeric.clip.register(tf.Tensor) diff --git a/nncf/tensor/tensor.py b/nncf/tensor/tensor.py index 998b3dc2dfb..9ccd841b164 100644 --- a/nncf/tensor/tensor.py +++ b/nncf/tensor/tensor.py @@ -140,7 +140,7 @@ def __floordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: def __rfloordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: return cast(Tensor, _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv)) - def __ifloordiv__(self, other: Union[Tensor, float]) -> Tensor: + def __ifloordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: self._data //= unwrap_tensor_data(other) return self diff --git a/nncf/version.py b/nncf/version.py index 26ec55fa1da..f7b5b2206e3 100644 --- a/nncf/version.py +++ b/nncf/version.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.16.0.dev0+b1af5d11dirty" +__version__ = "2.16.0" BKC_TORCH_SPEC = "==2.6.*" diff --git a/tests/cross_fw/test_templates/template_test_nncf_tensor.py b/tests/cross_fw/test_templates/template_test_nncf_tensor.py index fca455f3b18..3cb154a6a5e 100644 --- a/tests/cross_fw/test_templates/template_test_nncf_tensor.py +++ b/tests/cross_fw/test_templates/template_test_nncf_tensor.py @@ -113,8 +113,7 @@ def test_operators_tensor(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) - if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU): - assert res_nncf.device == nncf_tensor_a.device + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", OPERATOR_MAP.keys()) def test_operators_int(self, op_name): From 9a0c6d71a4081f4acf905215eb3a37872b44f3b4 Mon Sep 17 00:00:00 2001 From: darshil929 Date: Tue, 18 Mar 2025 01:22:05 +0530 Subject: [PATCH 10/17] update copyright headers to solve pre-commit issue --- nncf/tensor/functions/tf_io.py | 2 +- nncf/tensor/functions/tf_linalg.py | 2 +- nncf/tensor/functions/tf_numeric.py | 2 +- tests/tensorflow/test_tensor.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nncf/tensor/functions/tf_io.py b/nncf/tensor/functions/tf_io.py index cf1a5c24f2b..f8dc4af8e68 100644 --- a/nncf/tensor/functions/tf_io.py +++ b/nncf/tensor/functions/tf_io.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2025 Intel Corporation # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py index 1b100b9f585..1e7e3ec14c2 100644 --- a/nncf/tensor/functions/tf_linalg.py +++ b/nncf/tensor/functions/tf_linalg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2025 Intel Corporation # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index 28dcf80c18b..13a5143c9e5 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2025 Intel Corporation # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/tensorflow/test_tensor.py b/tests/tensorflow/test_tensor.py index e5596201e0d..c0c97e10b91 100644 --- a/tests/tensorflow/test_tensor.py +++ b/tests/tensorflow/test_tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2025 Intel Corporation # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at From 49b34e34f8086c3f5ec71e88a86c9078b2fbf595 Mon Sep 17 00:00:00 2001 From: darshil929 Date: Tue, 18 Mar 2025 18:57:25 +0530 Subject: [PATCH 11/17] update mypy workflow to include TensorFlow --- .github/workflows/mypy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index ba2dc684c71..5bc9aa44a46 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -23,7 +23,7 @@ jobs: python-version: 3.10.14 - name: Install NNCF run: | - pip install -e . torch -c constraints.txt + pip install -e . torch tensorflow -c constraints.txt - name: Install mypy run: pip install mypy==1.8.0 - name: Run mypy From f82cfe2ca1697f4dddf7c874c321d1c3b424613f Mon Sep 17 00:00:00 2001 From: darshil929 Date: Tue, 18 Mar 2025 19:04:35 +0530 Subject: [PATCH 12/17] apply automatic code formatting from ruff --- nncf/tensor/functions/tf_linalg.py | 2 +- nncf/tensor/functions/tf_numeric.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py index 1e7e3ec14c2..498192923ce 100644 --- a/nncf/tensor/functions/tf_linalg.py +++ b/nncf/tensor/functions/tf_linalg.py @@ -137,4 +137,4 @@ def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> tf.Tensor: with tf.device(a.device): s, u, v = tf.linalg.svd(a, full_matrices=full_matrices) - return u, s, tf.transpose(v, conjugate=True) \ No newline at end of file + return u, s, tf.transpose(v, conjugate=True) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index 13a5143c9e5..de6807b2fba 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -547,4 +547,4 @@ def tensor( device = convert_to_tf_device(device) dtype = convert_to_tf_dtype(dtype) with tf.device(device): - return tf.constant(data, dtype=dtype) \ No newline at end of file + return tf.constant(data, dtype=dtype) From e509fa9582c3b99739bb1c263f39c25c25408214 Mon Sep 17 00:00:00 2001 From: darshil929 Date: Tue, 18 Mar 2025 20:33:52 +0530 Subject: [PATCH 13/17] fix tensorflow tensor registration in tf_linalg.py function decorators --- nncf/tensor/functions/tf_linalg.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py index 498192923ce..b7c52051f09 100644 --- a/nncf/tensor/functions/tf_linalg.py +++ b/nncf/tensor/functions/tf_linalg.py @@ -17,7 +17,7 @@ from nncf.tensor.functions import linalg -@linalg.norm.register(tf.Tensor) +@linalg.norm.register def _( a: tf.Tensor, ord: Optional[Union[str, float, int]] = None, @@ -86,7 +86,7 @@ def _( return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims) -@linalg.cholesky.register(tf.Tensor) +@linalg.cholesky.register def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor: with tf.device(a.device): cholesky = tf.linalg.cholesky(a) @@ -97,7 +97,7 @@ def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor: return cholesky -@linalg.cholesky_inverse.register(tf.Tensor) +@linalg.cholesky_inverse.register def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor: with tf.device(a.device): if upper: @@ -109,19 +109,19 @@ def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor: return tf.linalg.cholesky_solve(a, eye) -@linalg.inv.register(tf.Tensor) +@linalg.inv.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.linalg.inv(a) -@linalg.pinv.register(tf.Tensor) +@linalg.pinv.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.linalg.pinv(a) -@linalg.lstsq.register(tf.Tensor) +@linalg.lstsq.register def _(a: tf.Tensor, b: tf.Tensor, driver: Optional[str] = None) -> tf.Tensor: with tf.device(a.device): if driver is not None: @@ -132,7 +132,7 @@ def _(a: tf.Tensor, b: tf.Tensor, driver: Optional[str] = None) -> tf.Tensor: return tf.linalg.lstsq(a, b) -@linalg.svd.register(tf.Tensor) +@linalg.svd.register def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> tf.Tensor: with tf.device(a.device): s, u, v = tf.linalg.svd(a, full_matrices=full_matrices) From a2b7875484d253982645dc4991110503663ba585 Mon Sep 17 00:00:00 2001 From: darshil929 Date: Tue, 18 Mar 2025 20:41:08 +0530 Subject: [PATCH 14/17] fix tensorflow tensor registration in tf_numeric.py function decorators --- nncf/tensor/functions/tf_numeric.py | 108 ++++++++++++++-------------- 1 file changed, 54 insertions(+), 54 deletions(-) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index de6807b2fba..e3d165be5e1 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -46,7 +46,7 @@ def convert_to_tf_dtype(dtype: TensorDataType) -> tf.DType: return DTYPE_MAP[dtype] if dtype is not None else None -@numeric.device.register(tf.Tensor) +@numeric.device.register def _(a: tf.Tensor) -> TensorDeviceType: if "CPU" in a.device: return DEVICE_MAP_REV["CPU"] @@ -54,59 +54,59 @@ def _(a: tf.Tensor) -> TensorDeviceType: return DEVICE_MAP_REV["GPU"] -@numeric.backend.register(tf.Tensor) +@numeric.backend.register def _(a: tf.Tensor) -> TensorBackend: return TensorBackend.tf -@numeric.squeeze.register(tf.Tensor) +@numeric.squeeze.register def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: with tf.device(a.device): return tf.squeeze(a, axis) -@numeric.flatten.register(tf.Tensor) +@numeric.flatten.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.reshape(a, [-1]) -@numeric.max.register(tf.Tensor) +@numeric.max.register def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor: with tf.device(a.device): return tf.reduce_max(a, axis=axis, keepdims=keepdims) -@numeric.min.register(tf.Tensor) +@numeric.min.register def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor: with tf.device(a.device): return tf.reduce_min(a, axis=axis, keepdims=keepdims) -@numeric.abs.register(tf.Tensor) +@numeric.abs.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.abs(a) -@numeric.astype.register(tf.Tensor) +@numeric.astype.register def _(a: tf.Tensor, dtype: TensorDataType) -> tf.Tensor: with tf.device(a.device): return tf.cast(a, DTYPE_MAP[dtype]) -@numeric.dtype.register(tf.Tensor) +@numeric.dtype.register def _(a: tf.Tensor) -> TensorDataType: return DTYPE_MAP_REV[a.dtype] -@numeric.reshape.register(tf.Tensor) +@numeric.reshape.register def _(a: tf.Tensor, shape: Tuple[int, ...]) -> tf.Tensor: with tf.device(a.device): return tf.reshape(a, shape) -@numeric.all.register(tf.Tensor) +@numeric.all.register def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: with tf.device(a.device): if axis is None: @@ -114,7 +114,7 @@ def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Te return tf.reduce_all(a, axis=axis) -@numeric.allclose.register(tf.Tensor) +@numeric.allclose.register def _( a: tf.Tensor, b: Union[tf.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> bool: @@ -122,7 +122,7 @@ def _( return bool(tf.experimental.numpy.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) -@numeric.any.register(tf.Tensor) +@numeric.any.register def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: with tf.device(a.device): if axis is None: @@ -130,18 +130,18 @@ def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Te return tf.reduce_any(a, axis=axis) -@numeric.count_nonzero.register(tf.Tensor) +@numeric.count_nonzero.register def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: with tf.device(a.device): return tf.math.count_nonzero(a, axis=axis) -@numeric.isempty.register(tf.Tensor) +@numeric.isempty.register def _(a: tf.Tensor) -> bool: return bool(tf.equal(tf.size(a), 0)) -@numeric.isclose.register(tf.Tensor) +@numeric.isclose.register def _( a: tf.Tensor, b: Union[tf.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> tf.Tensor: @@ -149,49 +149,49 @@ def _( return tf.experimental.numpy.isclose(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan) -@numeric.maximum.register(tf.Tensor) +@numeric.maximum.register def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor: with tf.device(x1.device): return tf.maximum(x1, x2) -@numeric.minimum.register(tf.Tensor) +@numeric.minimum.register def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor: with tf.device(x1.device): return tf.minimum(x1, x2) -@numeric.ones_like.register(tf.Tensor) +@numeric.ones_like.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.ones_like(a) -@numeric.where.register(tf.Tensor) +@numeric.where.register def _(condition: tf.Tensor, x: Union[tf.Tensor, float, bool], y: Union[tf.Tensor, float, bool]) -> tf.Tensor: with tf.device(condition.device): return tf.where(condition, x, y) -@numeric.zeros_like.register(tf.Tensor) +@numeric.zeros_like.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.zeros_like(a) -@numeric.stack.register(tf.Tensor) +@numeric.stack.register def _(x: List[tf.Tensor], axis: int = 0) -> tf.Tensor: with tf.device(x[0].device): return tf.stack(x, axis=axis) -@numeric.concatenate.register(tf.Tensor) +@numeric.concatenate.register def _(x: List[tf.Tensor], axis: int = 0) -> tf.Tensor: with tf.device(x[0].device): return tf.concat(x, axis=axis) -@numeric.unstack.register(tf.Tensor) +@numeric.unstack.register def _(x: tf.Tensor, axis: int = 0) -> List[tf.Tensor]: with tf.device(x.device): if not list(x.shape): @@ -199,13 +199,13 @@ def _(x: tf.Tensor, axis: int = 0) -> List[tf.Tensor]: return tf.unstack(x, axis=axis) -@numeric.moveaxis.register(tf.Tensor) +@numeric.moveaxis.register def _(a: tf.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> tf.Tensor: with tf.device(a.device): return tf.experimental.numpy.moveaxis(a, source, destination) -@numeric.mean.register(tf.Tensor) +@numeric.mean.register def _( a: tf.Tensor, axis: Union[int, Tuple[int, ...]] = None, @@ -217,7 +217,7 @@ def _( return tf.reduce_mean(a, axis=axis, keepdims=keepdims) -@numeric.median.register(tf.Tensor) +@numeric.median.register def _( a: tf.Tensor, axis: Union[int, Tuple[int, ...]] = None, @@ -250,7 +250,7 @@ def _( return median -@numeric.round.register(tf.Tensor) +@numeric.round.register def _(a: tf.Tensor, decimals: int = 0) -> tf.Tensor: scale_factor = 10**decimals scaled_tensor = a * scale_factor @@ -259,13 +259,13 @@ def _(a: tf.Tensor, decimals: int = 0) -> tf.Tensor: return rounded_tensor / scale_factor -@numeric.power.register(tf.Tensor) +@numeric.power.register def _(a: tf.Tensor, exponent: Union[tf.Tensor, float]) -> tf.Tensor: with tf.device(a.device): return tf.pow(a, exponent) -@numeric.quantile.register(tf.Tensor) +@numeric.quantile.register def quantile( a: tf.Tensor, q: Union[float, List[float]], @@ -278,7 +278,7 @@ def quantile( return tf.constant(quantile_np) -@numeric.percentile.register(tf.Tensor) +@numeric.percentile.register def _( a: tf.Tensor, q: Union[float, List[float]], @@ -290,54 +290,54 @@ def _( return numeric.quantile(a, q=q, axis=axis, keepdims=keepdims) -@numeric._binary_op_nowarn.register(tf.Tensor) +@numeric._binary_op_nowarn.register def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor: with tf.device(a.device): return tf.identity(operator_fn(a, b)) -@numeric._binary_reverse_op_nowarn.register(tf.Tensor) +@numeric._binary_reverse_op_nowarn.register def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor: with tf.device(a.device): return tf.identity(operator_fn(b, a)) -@numeric.clip.register(tf.Tensor) +@numeric.clip.register def _(a: tf.Tensor, a_min: Union[tf.Tensor, float], a_max: Union[tf.Tensor, float]) -> tf.Tensor: with tf.device(a.device): return tf.clip_by_value(a, a_min, a_max) -@numeric.finfo.register(tf.Tensor) +@numeric.finfo.register def _(a: tf.Tensor) -> TypeInfo: ti = tf.experimental.numpy.finfo(a.dtype) return TypeInfo(ti.eps, ti.max, ti.min) -@numeric.as_tensor_like.register(tf.Tensor) +@numeric.as_tensor_like.register def _(a: tf.Tensor, data: Any) -> tf.Tensor: with tf.device(a.device): return tf.convert_to_tensor(data) -@numeric.item.register(tf.Tensor) +@numeric.item.register def _(a: tf.Tensor) -> Union[int, float, bool]: return a.numpy().item() -@numeric.sum.register(tf.Tensor) +@numeric.sum.register def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor: with tf.device(a.device): return tf.reduce_sum(a, axis=axis, keepdims=keepdims) -@numeric.multiply.register(tf.Tensor) +@numeric.multiply.register def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor: with tf.device(x1.device): return tf.multiply(x1, x2) -@numeric.var.register(tf.Tensor) +@numeric.var.register def _( a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ddof: int = 0 ) -> tf.Tensor: @@ -349,37 +349,37 @@ def _( return tf_var -@numeric.size.register(tf.Tensor) +@numeric.size.register def _(a: tf.Tensor) -> int: return tf.size(a) -@numeric.matmul.register(tf.Tensor) +@numeric.matmul.register def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: with tf.device(x1.device): return tf.matmul(x1, x2) -@numeric.unsqueeze.register(tf.Tensor) +@numeric.unsqueeze.register def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: with tf.device(a.device): return tf.expand_dims(a, axis=axis) -@numeric.transpose.register(tf.Tensor) +@numeric.transpose.register def _(a: tf.Tensor, axes: Optional[Tuple[int, ...]] = None) -> tf.Tensor: with tf.device(a.device): return tf.transpose(a, perm=axes) -@numeric.argsort.register(tf.Tensor) +@numeric.argsort.register def _(a: tf.Tensor, axis: int = -1, descending=False, stable=False) -> tf.Tensor: with tf.device(a.device): direction = "DESCENDING" if descending else "ASCENDING" return tf.argsort(a, axis=axis, direction=direction, stable=stable) -@numeric.diag.register(tf.Tensor) +@numeric.diag.register def _(a: tf.Tensor, k: int = 0) -> tf.Tensor: with tf.device(a.device): rank = tf.rank(a) @@ -392,13 +392,13 @@ def _(a: tf.Tensor, k: int = 0) -> tf.Tensor: raise ValueError(error_msg) -@numeric.logical_or.register(tf.Tensor) +@numeric.logical_or.register def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: with tf.device(x1.device): return tf.logical_or(x1, x2) -@numeric.masked_mean.register(tf.Tensor) +@numeric.masked_mean.register def _( x: tf.Tensor, mask: Optional[tf.Tensor], axis: Union[int, Tuple[int, ...], List[int]], keepdims=False ) -> tf.Tensor: @@ -416,7 +416,7 @@ def _( return ret -@numeric.masked_median.register(tf.Tensor) +@numeric.masked_median.register def _( x: tf.Tensor, mask: Optional[tf.Tensor], axis: Union[int, Tuple[int, ...], List[int]], keepdims=False ) -> tf.Tensor: @@ -434,7 +434,7 @@ def _( return ret -@numeric.expand_dims.register(tf.Tensor) +@numeric.expand_dims.register def _(a: tf.Tensor, axis: Union[int, Tuple[int, ...], List[int]]) -> np.ndarray: if type(axis) not in (tuple, list): axis = (axis,) @@ -457,13 +457,13 @@ def _(a: tf.Tensor, axis: Union[int, Tuple[int, ...], List[int]]) -> np.ndarray: return tf.reshape(a, shape) -@numeric.clone.register(tf.Tensor) +@numeric.clone.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.identity(a) -@numeric.searchsorted.register(tf.Tensor) +@numeric.searchsorted.register def _(a: tf.Tensor, v: tf.Tensor, side: str = "left", sorter: Optional[tf.Tensor] = None) -> tf.Tensor: if side not in ["right", "left"]: error_msg = f"Invalid value for 'side': {side}. Expected 'right' or 'left'." @@ -526,13 +526,13 @@ def from_numpy(ndarray: np.ndarray) -> tf.Tensor: return tf.constant(ndarray) -@numeric.log2.register(tf.Tensor) +@numeric.log2.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.math.log(a) / tf.math.log(2.0) -@numeric.ceil.register(tf.Tensor) +@numeric.ceil.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): return tf.math.ceil(a) From 10a38ada2790ae377eb2bea4d9ea89a91cb8aa1a Mon Sep 17 00:00:00 2001 From: darshil929 Date: Wed, 19 Mar 2025 01:33:04 +0530 Subject: [PATCH 15/17] fix type annotations and error handling to resolve mypy and tensorflow pytest CI checks --- nncf/tensor/functions/numpy_io.py | 2 +- nncf/tensor/functions/numpy_numeric.py | 2 +- nncf/tensor/functions/tf_linalg.py | 4 +- nncf/tensor/functions/tf_numeric.py | 78 ++++++++++++++------------ 4 files changed, 45 insertions(+), 41 deletions(-) diff --git a/nncf/tensor/functions/numpy_io.py b/nncf/tensor/functions/numpy_io.py index 2c9464f1a54..cf5badb4cb7 100644 --- a/nncf/tensor/functions/numpy_io.py +++ b/nncf/tensor/functions/numpy_io.py @@ -23,7 +23,7 @@ from nncf.tensor.functions.numpy_numeric import validate_device T_NUMPY_ARRAY = NDArray[Any] -T_NUMPY = Union[T_NUMPY_ARRAY, np.generic] # type: ignore [type-arg] +T_NUMPY = Union[T_NUMPY_ARRAY, np.generic] def load_file(file_path: str, *, device: Optional[TensorDeviceType] = None) -> Dict[str, T_NUMPY_ARRAY]: diff --git a/nncf/tensor/functions/numpy_numeric.py b/nncf/tensor/functions/numpy_numeric.py index f52f4a69470..5ada04f9467 100644 --- a/nncf/tensor/functions/numpy_numeric.py +++ b/nncf/tensor/functions/numpy_numeric.py @@ -27,7 +27,7 @@ from nncf.tensor.tensor import TTensor T_NUMPY_ARRAY = NDArray[Any] -T_NUMPY = Union[T_NUMPY_ARRAY, np.generic] # type: ignore [type-arg] +T_NUMPY = Union[T_NUMPY_ARRAY, np.generic] DTYPE_MAP: Dict[TensorDataType, DTypeLike] = { TensorDataType.float16: np.dtype(np.float16), diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py index b7c52051f09..6f594e414a2 100644 --- a/nncf/tensor/functions/tf_linalg.py +++ b/nncf/tensor/functions/tf_linalg.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import tensorflow as tf @@ -20,7 +20,7 @@ @linalg.norm.register def _( a: tf.Tensor, - ord: Optional[Union[str, float, int]] = None, + ord: Union[Literal["fro", "nuc"], float, None] = None, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> tf.Tensor: diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index e3d165be5e1..b298bdab685 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -12,6 +12,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import numpy as np +import numpy.typing as npt import tensorflow as tf from nncf.tensor import TensorDataType @@ -38,11 +39,11 @@ DEVICE_MAP_REV = {v: k for k, v in DEVICE_MAP.items()} -def convert_to_tf_device(device: TensorDeviceType) -> str: +def convert_to_tf_device(device: Optional[TensorDeviceType]) -> Optional[str]: return DEVICE_MAP[device] if device is not None else None -def convert_to_tf_dtype(dtype: TensorDataType) -> tf.DType: +def convert_to_tf_dtype(dtype: Optional[TensorDataType]) -> Optional[tf.DType]: return DTYPE_MAP[dtype] if dtype is not None else None @@ -52,6 +53,7 @@ def _(a: tf.Tensor) -> TensorDeviceType: return DEVICE_MAP_REV["CPU"] if "GPU" in a.device: return DEVICE_MAP_REV["GPU"] + return TensorDeviceType.CPU @numeric.backend.register @@ -243,7 +245,7 @@ def _( else: median = tf.gather(top_k, indices=[k - 1], axis=-1) median = tf.squeeze(median, axis=-1) - if keepdims: + if keepdims and axis is not None: for axe in sorted(axis, key=lambda x: abs(x)): median = tf.expand_dims(median, axe) @@ -282,22 +284,24 @@ def quantile( def _( a: tf.Tensor, q: Union[float, List[float]], - axis: Union[int, Tuple[int, ...], List[int]], + axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, -) -> List[Union[tf.Tensor, np.generic]]: +) -> tf.Tensor: with tf.device(a.device): q = [x / 100 for x in q] if isinstance(q, (list, tuple)) else q / 100 + if isinstance(axis, list): + axis = tuple(axis) return numeric.quantile(a, q=q, axis=axis, keepdims=keepdims) @numeric._binary_op_nowarn.register -def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor: +def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable[[tf.Tensor, Union[tf.Tensor, float]], tf.Tensor]) -> tf.Tensor: with tf.device(a.device): return tf.identity(operator_fn(a, b)) @numeric._binary_reverse_op_nowarn.register -def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor: +def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable[[Union[tf.Tensor, float], tf.Tensor], tf.Tensor]) -> tf.Tensor: with tf.device(a.device): return tf.identity(operator_fn(b, a)) @@ -435,18 +439,24 @@ def _( @numeric.expand_dims.register -def _(a: tf.Tensor, axis: Union[int, Tuple[int, ...], List[int]]) -> np.ndarray: - if type(axis) not in (tuple, list): - axis = (axis,) - - if len(set(axis)) != len(axis): +def _(a: tf.Tensor, axis: Union[int, Tuple[int, ...], List[int]]) -> tf.Tensor: + if isinstance(axis, int): + axes_tuple: Tuple[int, ...] = (axis,) + elif isinstance(axis, list): + axes_tuple: Tuple[int, ...] = tuple(axis) + elif isinstance(axis, tuple): + axes_tuple: Tuple[int, ...] = axis + else: + raise TypeError(f"axis must be int, tuple, or list, got {type(axis)}") + + if len(set(axes_tuple)) != len(axes_tuple): error_msg = "repeated axis" raise ValueError(error_msg) - out_ndim = len(axis) + a.ndim + out_ndim = len(axes_tuple) + a.ndim norm_axis = [] - for ax in axis: + for ax in axes_tuple: if ax < -out_ndim or ax >= out_ndim: error_msg = f"axis {ax} is out of bounds for array of dimension {out_ndim}" raise ValueError(error_msg) @@ -481,12 +491,10 @@ def zeros( dtype: Optional[TensorDataType] = None, device: Optional[TensorDeviceType] = None, ) -> tf.Tensor: - if dtype is not None: - dtype = DTYPE_MAP[dtype] - if device is not None: - device = DEVICE_MAP[device] - with tf.device(device): - return tf.zeros(shape, dtype=dtype) + tf_dtype = DTYPE_MAP[dtype] if dtype is not None else None + tf_device = DEVICE_MAP[device] if device is not None else None + with tf.device(tf_device): + return tf.zeros(shape, dtype=tf_dtype) def eye( @@ -496,13 +504,11 @@ def eye( dtype: Optional[TensorDataType] = None, device: Optional[TensorDeviceType] = None, ) -> tf.Tensor: - if dtype is not None: - dtype = DTYPE_MAP[dtype] - if device is not None: - device = DEVICE_MAP[device] + tf_dtype = DTYPE_MAP[dtype] if dtype is not None else None + tf_device = DEVICE_MAP[device] if device is not None else None p_args = (n,) if m is None else (n, m) - with tf.device(device): - return tf.eye(*p_args, dtype=dtype) + with tf.device(tf_device): + return tf.eye(*p_args, dtype=tf_dtype) def arange( @@ -513,15 +519,13 @@ def arange( dtype: Optional[TensorDataType] = None, device: Optional[TensorDeviceType] = None, ) -> tf.Tensor: - if dtype is not None: - dtype = DTYPE_MAP[dtype] - if device is not None: - device = DEVICE_MAP[device] - with tf.device(device): - return tf.range(start, end, step, dtype=dtype) + tf_dtype = DTYPE_MAP[dtype] if dtype is not None else None + tf_device = DEVICE_MAP[device] if device is not None else None + with tf.device(tf_device): + return tf.range(start, end, step, dtype=tf_dtype) -def from_numpy(ndarray: np.ndarray) -> tf.Tensor: +def from_numpy(ndarray: npt.NDArray[Any]) -> tf.Tensor: with tf.device("CPU"): return tf.constant(ndarray) @@ -544,7 +548,7 @@ def tensor( dtype: Optional[TensorDataType] = None, device: Optional[TensorDeviceType] = None, ) -> tf.Tensor: - device = convert_to_tf_device(device) - dtype = convert_to_tf_dtype(dtype) - with tf.device(device): - return tf.constant(data, dtype=dtype) + tf_device = convert_to_tf_device(device) + tf_dtype = convert_to_tf_dtype(dtype) + with tf.device(tf_device): + return tf.constant(data, dtype=tf_dtype) From 5768fb64d6a84171ffed83a84ec574a5f2062cbe Mon Sep 17 00:00:00 2001 From: darshil929 Date: Sat, 22 Mar 2025 02:06:44 +0530 Subject: [PATCH 16/17] fix tensorflow tensor implementation with norm fixes and device preservation --- nncf/tensor/functions/numeric.py | 32 ++++++++ nncf/tensor/functions/tf_io.py | 16 +++- nncf/tensor/functions/tf_linalg.py | 57 +++++++++++--- nncf/tensor/functions/tf_numeric.py | 114 +++++++++++++++++++++------- nncf/tensor/tensor.py | 20 +++++ 5 files changed, 202 insertions(+), 37 deletions(-) diff --git a/nncf/tensor/functions/numeric.py b/nncf/tensor/functions/numeric.py index d1dc51577b5..273406a6bca 100644 --- a/nncf/tensor/functions/numeric.py +++ b/nncf/tensor/functions/numeric.py @@ -105,6 +105,16 @@ def abs(a: Tensor) -> Tensor: """ +@tensor_dispatcher +def neg(a: Tensor) -> Tensor: + """ + Numerical negative, element-wise. + + :param a: The input tensor. + :return: A tensor containing the negative value of each element in a. + """ + + @tensor_dispatcher def astype(a: Tensor, dtype: TensorDataType) -> Tensor: """ @@ -493,6 +503,28 @@ def sum(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: """ +@tensor_dispatcher +def add(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: + """ + Add two tensors element-wise. + + :param x1: The first input tensor. + :param x2: The second input tensor or number. + :return: The sum of x1 and x2, element-wise. + """ + + +@tensor_dispatcher +def subtract(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: + """ + Subtract two tensors element-wise. + + :param x1: The first input tensor. + :param x2: The second input tensor or number. + :return: The result of x1 - x2, element-wise. + """ + + @tensor_dispatcher def multiply(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ diff --git a/nncf/tensor/functions/tf_io.py b/nncf/tensor/functions/tf_io.py index f8dc4af8e68..d4dbdcaa3b3 100644 --- a/nncf/tensor/functions/tf_io.py +++ b/nncf/tensor/functions/tf_io.py @@ -18,12 +18,26 @@ from nncf.tensor import TensorDeviceType from nncf.tensor.functions import io as io +from nncf.tensor.functions.tf_numeric import DEVICE_MAP def load_file(file_path: Path, *, device: Optional[TensorDeviceType] = None) -> Dict[str, tf.Tensor]: - return tf_load_file(file_path) + loaded_tensors = tf_load_file(file_path) + + if device is not None: + device_str = DEVICE_MAP[device] + with tf.device(device_str): + loaded_tensors = {k: tf.identity(v) for k, v in loaded_tensors.items()} + + return loaded_tensors @io.save_file.register def _(data: Dict[str, tf.Tensor], file_path: Path) -> None: + if file_path.is_symlink(): + from nncf.errors import ValidationError + + error_msg = "Cannot save tensor to a symbolic link" + raise ValidationError(error_msg) + return tf_save_file(data, file_path) diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py index 6f594e414a2..2248701b68f 100644 --- a/nncf/tensor/functions/tf_linalg.py +++ b/nncf/tensor/functions/tf_linalg.py @@ -24,9 +24,14 @@ def _( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> tf.Tensor: - if ord is None: - ord = "euclidean" rank = tf.rank(a) + + if ord is None: + if axis is None and rank == 2: + ord = "fro" + else: + ord = 2 + if rank == 2 and axis is None: axis = (0, 1) @@ -49,41 +54,75 @@ def _( if rank != 2: error_msg = "ord=-1 is only supported for 2D tensors" raise ValueError(error_msg) - return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) + result = tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) + if keepdims: + result = tf.reshape(result, [1, 1]) + return result if ord == 1 and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: error_msg = "ord=1 is only supported for 2D tensors" raise ValueError(error_msg) - return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) + result = tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) + if keepdims: + result = tf.reshape(result, [1, 1]) + return result if ord == -2 and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: error_msg = "ord=-2 is only supported for 2D tensors" raise ValueError(error_msg) s = tf.linalg.svd(a, compute_uv=False) - return tf.reduce_min(s, axis=-1) + result = tf.reduce_min(s, axis=-1) + if keepdims: + result = tf.reshape(result, [1, 1]) + return result if ord == 2 and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: error_msg = "ord=2 is only supported for 2D tensors" raise ValueError(error_msg) s = tf.linalg.svd(a, compute_uv=False) - return tf.reduce_max(s, axis=-1) + result = tf.reduce_max(s, axis=-1) + if keepdims: + result = tf.reshape(result, [1, 1]) + return result if ord == float("inf") and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: error_msg = "ord=inf is only supported for 2D tensors" raise ValueError(error_msg) - return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) + result = tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) + if keepdims: + result = tf.reshape(result, [1, 1]) + return result if ord == -float("inf") and isinstance(axis, tuple) and len(axis) != 1: if rank != 2: error_msg = "ord=-inf is only supported for 2D tensors" raise ValueError(error_msg) - return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) + result = tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) + if keepdims: + result = tf.reshape(result, [1, 1]) + return result - return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims) + try: + return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims) + except (TypeError, ValueError): + if axis is not None: + if ord == 2: + squared = tf.square(a) + sum_squares = tf.reduce_sum(squared, axis=axis, keepdims=keepdims) + return tf.sqrt(sum_squares) + elif ord == 1: + return tf.reduce_sum(tf.abs(a), axis=axis, keepdims=keepdims) + elif ord == float("inf"): + return tf.reduce_max(tf.abs(a), axis=axis, keepdims=keepdims) + elif ord == -float("inf"): + return tf.reduce_min(tf.abs(a), axis=axis, keepdims=keepdims) + + error_msg = f"Unsupported combination of ord={ord} and axis={axis}" + raise ValueError(error_msg) @linalg.cholesky.register diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index b298bdab685..82bfecbed6d 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Literal, Optional, Sequence, Tuple, Union import numpy as np import numpy.typing as npt @@ -103,7 +103,7 @@ def _(a: tf.Tensor) -> TensorDataType: @numeric.reshape.register -def _(a: tf.Tensor, shape: Tuple[int, ...]) -> tf.Tensor: +def reshape(a: tf.Tensor, shape: Union[int, Tuple[int, ...]]) -> tf.Tensor: with tf.device(a.device): return tf.reshape(a, shape) @@ -210,7 +210,7 @@ def _(a: tf.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, @numeric.mean.register def _( a: tf.Tensor, - axis: Union[int, Tuple[int, ...]] = None, + axis: Optional[Union[Tuple[int, ...], int]] = None, keepdims: bool = False, dtype: Optional[TensorDataType] = None, ) -> tf.Tensor: @@ -222,7 +222,7 @@ def _( @numeric.median.register def _( a: tf.Tensor, - axis: Union[int, Tuple[int, ...]] = None, + axis: Optional[Union[Tuple[int, ...], int]] = None, keepdims: bool = False, ) -> tf.Tensor: with tf.device(a.device): @@ -264,14 +264,21 @@ def _(a: tf.Tensor, decimals: int = 0) -> tf.Tensor: @numeric.power.register def _(a: tf.Tensor, exponent: Union[tf.Tensor, float]) -> tf.Tensor: with tf.device(a.device): - return tf.pow(a, exponent) + if not isinstance(exponent, tf.Tensor): + exponent_tensor = tf.convert_to_tensor(exponent, dtype=a.dtype) + else: + with tf.device(a.device): + exponent_tensor = tf.identity(exponent) + + result = tf.pow(a, exponent_tensor) + return tf.identity(result) @numeric.quantile.register def quantile( a: tf.Tensor, q: Union[float, List[float]], - axis: Optional[Union[int, Tuple[int]]] = None, + axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> tf.Tensor: a_np = a.numpy() @@ -281,29 +288,69 @@ def quantile( @numeric.percentile.register -def _( +def percentile( a: tf.Tensor, q: Union[float, List[float]], - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: Optional[Union[Tuple[int, ...], int]], keepdims: bool = False, ) -> tf.Tensor: with tf.device(a.device): q = [x / 100 for x in q] if isinstance(q, (list, tuple)) else q / 100 if isinstance(axis, list): axis = tuple(axis) - return numeric.quantile(a, q=q, axis=axis, keepdims=keepdims) + return quantile(a, q=q, axis=axis, keepdims=keepdims) @numeric._binary_op_nowarn.register -def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable[[tf.Tensor, Union[tf.Tensor, float]], tf.Tensor]) -> tf.Tensor: +def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable[..., Any]) -> tf.Tensor: with tf.device(a.device): - return tf.identity(operator_fn(a, b)) + if not isinstance(b, tf.Tensor) and isinstance(b, (int, float)): + b = tf.convert_to_tensor(b, dtype=a.dtype) + result = operator_fn(a, b) + return tf.identity(result) @numeric._binary_reverse_op_nowarn.register -def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable[[Union[tf.Tensor, float], tf.Tensor], tf.Tensor]) -> tf.Tensor: +def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable[..., Any]) -> tf.Tensor: with tf.device(a.device): - return tf.identity(operator_fn(b, a)) + if not isinstance(b, tf.Tensor) and isinstance(b, (int, float)): + b = tf.convert_to_tensor(b, dtype=a.dtype) + result = operator_fn(b, a) + return tf.identity(result) + + +@numeric.add.register +def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor: + with tf.device(x1.device): + if not isinstance(x2, tf.Tensor): + x2 = tf.convert_to_tensor(x2, dtype=x1.dtype) + result = tf.add(x1, x2) + return tf.identity(result) + + +@numeric.subtract.register +def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor: + with tf.device(x1.device): + if not isinstance(x2, tf.Tensor): + x2 = tf.convert_to_tensor(x2, dtype=x1.dtype) + result = tf.subtract(x1, x2) + return tf.identity(result) + + +@numeric.multiply.register +def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor: + with tf.device(x1.device): + if not isinstance(x2, tf.Tensor): + x2 = tf.convert_to_tensor(x2, dtype=x1.dtype) + result = tf.multiply(x1, x2) + return tf.identity(result) + + +@numeric.neg.register +def _(a: tf.Tensor) -> tf.Tensor: + with tf.device(a.device): + result = tf.negative(a) + return tf.identity(result) @numeric.clip.register @@ -365,7 +412,7 @@ def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: @numeric.unsqueeze.register -def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: +def unsqueeze(a: tf.Tensor, axis: int) -> tf.Tensor: with tf.device(a.device): return tf.expand_dims(a, axis=axis) @@ -377,7 +424,7 @@ def _(a: tf.Tensor, axes: Optional[Tuple[int, ...]] = None) -> tf.Tensor: @numeric.argsort.register -def _(a: tf.Tensor, axis: int = -1, descending=False, stable=False) -> tf.Tensor: +def argsort(a: tf.Tensor, axis: int = -1, descending: bool = False, stable: bool = False) -> tf.Tensor: with tf.device(a.device): direction = "DESCENDING" if descending else "ASCENDING" return tf.argsort(a, axis=axis, direction=direction, stable=stable) @@ -403,9 +450,12 @@ def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: @numeric.masked_mean.register -def _( - x: tf.Tensor, mask: Optional[tf.Tensor], axis: Union[int, Tuple[int, ...], List[int]], keepdims=False +def masked_mean( + x: tf.Tensor, mask: Optional[tf.Tensor], axis: Optional[Union[int, Tuple[int, ...]]], keepdims: bool = False ) -> tf.Tensor: + if isinstance(axis, list): + axis = tuple(axis) + with tf.device(x.device): if mask is None: return tf.reduce_mean(x, axis=axis, keepdims=keepdims) @@ -421,12 +471,15 @@ def _( @numeric.masked_median.register -def _( - x: tf.Tensor, mask: Optional[tf.Tensor], axis: Union[int, Tuple[int, ...], List[int]], keepdims=False +def masked_median( + x: tf.Tensor, mask: Optional[tf.Tensor], axis: Optional[Union[int, Tuple[int, ...]]], keepdims: bool = False ) -> tf.Tensor: if mask is None: return numeric.median(x, axis=axis, keepdims=keepdims) + if isinstance(axis, list): + axis = tuple(axis) + masked_x = tf.where(mask, np.nan, x) np_masked_x = masked_x.numpy() np_masked_median = np.nanquantile(np_masked_x, 0.5, axis=axis, keepdims=keepdims) @@ -439,16 +492,16 @@ def _( @numeric.expand_dims.register -def _(a: tf.Tensor, axis: Union[int, Tuple[int, ...], List[int]]) -> tf.Tensor: +def expand_dims(a: tf.Tensor, axis: Union[int, Tuple[int, ...]]) -> tf.Tensor: + axes_tuple: Tuple[int, ...] if isinstance(axis, int): - axes_tuple: Tuple[int, ...] = (axis,) - elif isinstance(axis, list): - axes_tuple: Tuple[int, ...] = tuple(axis) + axes_tuple = (axis,) elif isinstance(axis, tuple): - axes_tuple: Tuple[int, ...] = axis + axes_tuple = axis else: - raise TypeError(f"axis must be int, tuple, or list, got {type(axis)}") - + error_msg = f"axis must be int or tuple, got {type(axis)}" + raise TypeError(error_msg) + if len(set(axes_tuple)) != len(axes_tuple): error_msg = "repeated axis" raise ValueError(error_msg) @@ -474,7 +527,9 @@ def _(a: tf.Tensor) -> tf.Tensor: @numeric.searchsorted.register -def _(a: tf.Tensor, v: tf.Tensor, side: str = "left", sorter: Optional[tf.Tensor] = None) -> tf.Tensor: +def searchsorted( + a: tf.Tensor, v: tf.Tensor, side: Literal["left", "right"] = "left", sorter: Optional[tf.Tensor] = None +) -> tf.Tensor: if side not in ["right", "left"]: error_msg = f"Invalid value for 'side': {side}. Expected 'right' or 'left'." raise ValueError(error_msg) @@ -530,6 +585,11 @@ def from_numpy(ndarray: npt.NDArray[Any]) -> tf.Tensor: return tf.constant(ndarray) +@numeric.as_numpy_tensor.register +def as_numpy_tensor(a: tf.Tensor) -> npt.NDArray[Any]: + return a.numpy() + + @numeric.log2.register def _(a: tf.Tensor) -> tf.Tensor: with tf.device(a.device): diff --git a/nncf/tensor/tensor.py b/nncf/tensor/tensor.py index 9ccd841b164..c1d200714c9 100644 --- a/nncf/tensor/tensor.py +++ b/nncf/tensor/tensor.py @@ -85,6 +85,10 @@ def __len__(self) -> int: # built-in operations def __add__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: + if self.backend == TensorBackend.tf: + from nncf.tensor.functions import numeric + + return numeric.add(self, other) return Tensor(self.data + unwrap_tensor_data(other)) def __radd__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: @@ -95,6 +99,10 @@ def __iadd__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: return self def __sub__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: + if self.backend == TensorBackend.tf: + from nncf.tensor.functions import numeric + + return numeric.subtract(self, other) return Tensor(self.data - unwrap_tensor_data(other)) def __rsub__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: @@ -105,6 +113,10 @@ def __isub__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: return self def __mul__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: + if self.backend == TensorBackend.tf: + from nncf.tensor.functions import numeric + + return numeric.multiply(self, other) return Tensor(self.data * unwrap_tensor_data(other)) def __rmul__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: @@ -115,6 +127,10 @@ def __imul__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: return self def __pow__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: + if self.backend == TensorBackend.tf: + from nncf.tensor.functions import numeric + + return numeric.power(self, other) return Tensor(self.data ** unwrap_tensor_data(other)) def __rpow__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: @@ -148,6 +164,10 @@ def __matmul__(self, other: Union[Tensor, T_NUMBER]) -> Tensor: return Tensor(self.data @ unwrap_tensor_data(other)) def __neg__(self) -> Tensor: + if self.backend == TensorBackend.tf: + from nncf.tensor.functions import numeric + + return numeric.neg(self) return Tensor(-self.data) # Comparison operators From 911382cef5ec0bfc2b4277b7d63b52d8edbdf113 Mon Sep 17 00:00:00 2001 From: darshil929 Date: Mon, 24 Mar 2025 18:10:48 +0530 Subject: [PATCH 17/17] add backend implementation for float type in tensorflow tensor --- nncf/tensor/functions/tf_numeric.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index 82bfecbed6d..df114b2d41a 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -61,6 +61,11 @@ def _(a: tf.Tensor) -> TensorBackend: return TensorBackend.tf +@numeric.backend.register +def _(a: float) -> TensorBackend: + return TensorBackend.numpy + + @numeric.squeeze.register def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: with tf.device(a.device):