Skip to content

Commit 5588b0a

Browse files
committed
fix tensorflow tensor implementation with norm fixes and device preservation
1 parent 10a38ad commit 5588b0a

File tree

5 files changed

+210
-38
lines changed

5 files changed

+210
-38
lines changed

nncf/tensor/functions/numeric.py

+32
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ def abs(a: Tensor) -> Tensor:
105105
"""
106106

107107

108+
@tensor_dispatcher
109+
def neg(a: Tensor) -> Tensor:
110+
"""
111+
Numerical negative, element-wise.
112+
113+
:param a: The input tensor.
114+
:return: A tensor containing the negative value of each element in a.
115+
"""
116+
117+
108118
@tensor_dispatcher
109119
def astype(a: Tensor, dtype: TensorDataType) -> Tensor:
110120
"""
@@ -493,6 +503,28 @@ def sum(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims:
493503
"""
494504

495505

506+
@tensor_dispatcher
507+
def add(x1: Tensor, x2: Union[Tensor, float]) -> Tensor:
508+
"""
509+
Add two tensors element-wise.
510+
511+
:param x1: The first input tensor.
512+
:param x2: The second input tensor or number.
513+
:return: The sum of x1 and x2, element-wise.
514+
"""
515+
516+
517+
@tensor_dispatcher
518+
def subtract(x1: Tensor, x2: Union[Tensor, float]) -> Tensor:
519+
"""
520+
Subtract two tensors element-wise.
521+
522+
:param x1: The first input tensor.
523+
:param x2: The second input tensor or number.
524+
:return: The result of x1 - x2, element-wise.
525+
"""
526+
527+
496528
@tensor_dispatcher
497529
def multiply(x1: Tensor, x2: Union[Tensor, float]) -> Tensor:
498530
"""

nncf/tensor/functions/tf_io.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,26 @@
1818

1919
from nncf.tensor import TensorDeviceType
2020
from nncf.tensor.functions import io as io
21+
from nncf.tensor.functions.tf_numeric import DEVICE_MAP
2122

2223

2324
def load_file(file_path: Path, *, device: Optional[TensorDeviceType] = None) -> Dict[str, tf.Tensor]:
24-
return tf_load_file(file_path)
25+
loaded_tensors = tf_load_file(file_path)
26+
27+
if device is not None:
28+
device_str = DEVICE_MAP[device]
29+
with tf.device(device_str):
30+
loaded_tensors = {k: tf.identity(v) for k, v in loaded_tensors.items()}
31+
32+
return loaded_tensors
2533

2634

2735
@io.save_file.register
2836
def _(data: Dict[str, tf.Tensor], file_path: Path) -> None:
37+
if file_path.is_symlink():
38+
from nncf.errors import ValidationError
39+
40+
error_msg = "Cannot save tensor to a symbolic link"
41+
raise ValidationError(error_msg)
42+
2943
return tf_save_file(data, file_path)

nncf/tensor/functions/tf_linalg.py

+48-9
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@ def _(
2424
axis: Optional[Union[int, Tuple[int, ...]]] = None,
2525
keepdims: bool = False,
2626
) -> tf.Tensor:
27-
if ord is None:
28-
ord = "euclidean"
2927
rank = tf.rank(a)
28+
29+
if ord is None:
30+
if axis is None and rank == 2:
31+
ord = "fro"
32+
else:
33+
ord = 2
34+
3035
if rank == 2 and axis is None:
3136
axis = (0, 1)
3237

@@ -49,41 +54,75 @@ def _(
4954
if rank != 2:
5055
error_msg = "ord=-1 is only supported for 2D tensors"
5156
raise ValueError(error_msg)
52-
return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims)
57+
result = tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims)
58+
if keepdims:
59+
result = tf.reshape(result, [1, 1])
60+
return result
5361

5462
if ord == 1 and isinstance(axis, tuple) and len(axis) != 1:
5563
if rank != 2:
5664
error_msg = "ord=1 is only supported for 2D tensors"
5765
raise ValueError(error_msg)
58-
return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims)
66+
result = tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims)
67+
if keepdims:
68+
result = tf.reshape(result, [1, 1])
69+
return result
5970

6071
if ord == -2 and isinstance(axis, tuple) and len(axis) != 1:
6172
if rank != 2:
6273
error_msg = "ord=-2 is only supported for 2D tensors"
6374
raise ValueError(error_msg)
6475
s = tf.linalg.svd(a, compute_uv=False)
65-
return tf.reduce_min(s, axis=-1)
76+
result = tf.reduce_min(s, axis=-1)
77+
if keepdims:
78+
result = tf.reshape(result, [1, 1])
79+
return result
6680

6781
if ord == 2 and isinstance(axis, tuple) and len(axis) != 1:
6882
if rank != 2:
6983
error_msg = "ord=2 is only supported for 2D tensors"
7084
raise ValueError(error_msg)
7185
s = tf.linalg.svd(a, compute_uv=False)
72-
return tf.reduce_max(s, axis=-1)
86+
result = tf.reduce_max(s, axis=-1)
87+
if keepdims:
88+
result = tf.reshape(result, [1, 1])
89+
return result
7390

7491
if ord == float("inf") and isinstance(axis, tuple) and len(axis) != 1:
7592
if rank != 2:
7693
error_msg = "ord=inf is only supported for 2D tensors"
7794
raise ValueError(error_msg)
78-
return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims)
95+
result = tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims)
96+
if keepdims:
97+
result = tf.reshape(result, [1, 1])
98+
return result
7999

80100
if ord == -float("inf") and isinstance(axis, tuple) and len(axis) != 1:
81101
if rank != 2:
82102
error_msg = "ord=-inf is only supported for 2D tensors"
83103
raise ValueError(error_msg)
84-
return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims)
104+
result = tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims)
105+
if keepdims:
106+
result = tf.reshape(result, [1, 1])
107+
return result
85108

86-
return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims)
109+
try:
110+
return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims)
111+
except (TypeError, ValueError):
112+
if axis is not None:
113+
if ord == 2:
114+
squared = tf.square(a)
115+
sum_squares = tf.reduce_sum(squared, axis=axis, keepdims=keepdims)
116+
return tf.sqrt(sum_squares)
117+
elif ord == 1:
118+
return tf.reduce_sum(tf.abs(a), axis=axis, keepdims=keepdims)
119+
elif ord == float("inf"):
120+
return tf.reduce_max(tf.abs(a), axis=axis, keepdims=keepdims)
121+
elif ord == -float("inf"):
122+
return tf.reduce_min(tf.abs(a), axis=axis, keepdims=keepdims)
123+
124+
error_msg = f"Unsupported combination of ord={ord} and axis={axis}"
125+
raise ValueError(error_msg)
87126

88127

89128
@linalg.cholesky.register

0 commit comments

Comments
 (0)