Skip to content

Commit 8d3b0da

Browse files
committed
address reviewer feedback: update file path types to Path, rollback version changes, and improve tensor device handling
1 parent 22a9580 commit 8d3b0da

File tree

5 files changed

+8
-8
lines changed

5 files changed

+8
-8
lines changed

nncf/tensor/functions/tf_io.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from pathlib import Path
1213
from typing import Dict, Optional
1314

1415
import tensorflow as tf
@@ -19,10 +20,10 @@
1920
from nncf.tensor.functions import io as io
2021

2122

22-
def load_file(file_path: str, *, device: Optional[TensorDeviceType] = None) -> Dict[str, tf.Tensor]:
23+
def load_file(file_path: Path, *, device: Optional[TensorDeviceType] = None) -> Dict[str, tf.Tensor]:
2324
return tf_load_file(file_path)
2425

2526

2627
@io.save_file.register
27-
def _(data: Dict[str, tf.Tensor], file_path: str) -> None:
28+
def _(data: Dict[str, tf.Tensor], file_path: Path) -> None:
2829
return tf_save_file(data, file_path)

nncf/tensor/functions/tf_numeric.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,13 @@ def _(
293293
@numeric._binary_op_nowarn.register(tf.Tensor)
294294
def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor:
295295
with tf.device(a.device):
296-
return operator_fn(a, b)
296+
return tf.identity(operator_fn(a, b))
297297

298298

299299
@numeric._binary_reverse_op_nowarn.register(tf.Tensor)
300300
def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor:
301301
with tf.device(a.device):
302-
return operator_fn(b, a)
302+
return tf.identity(operator_fn(b, a))
303303

304304

305305
@numeric.clip.register(tf.Tensor)

nncf/tensor/tensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __floordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor:
140140
def __rfloordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor:
141141
return cast(Tensor, _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv))
142142

143-
def __ifloordiv__(self, other: Union[Tensor, float]) -> Tensor:
143+
def __ifloordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor:
144144
self._data //= unwrap_tensor_data(other)
145145
return self
146146

nncf/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
__version__ = "2.16.0.dev0+b1af5d11dirty"
12+
__version__ = "2.16.0"
1313

1414

1515
BKC_TORCH_SPEC = "==2.6.*"

tests/cross_fw/test_templates/template_test_nncf_tensor.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ def test_operators_tensor(self, op_name):
113113
assert res.dtype == res_nncf.data.dtype
114114
assert all(res == res_nncf.data)
115115
assert isinstance(res_nncf, Tensor)
116-
if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU):
117-
assert res_nncf.device == nncf_tensor_a.device
116+
assert res_nncf.device == nncf_tensor_a.device
118117

119118
@pytest.mark.parametrize("op_name", OPERATOR_MAP.keys())
120119
def test_operators_int(self, op_name):

0 commit comments

Comments
 (0)