Skip to content

Commit b4cbfe7

Browse files
Rework tensor dispatcher (openvinotoolkit#3306)
### Changes Introduce tensor_dispatcher instead of functools.singledispatch - Simplified adding a new function - Dispatch mechanism based on signature - Check names of arguments and annotation - Enable mypy for `nncf/tensor` - Keep signature of function ![image](https://github.com/user-attachments/assets/bc53a27e-49ab-4cd6-816b-07e12e431f60) ### Reason for changes No signature of functions ![image](https://github.com/user-attachments/assets/7bbb4ec0-4323-49b0-bd2b-dd3ea8fe149b) Impossible use mypy without cast.
1 parent a8ba008 commit b4cbfe7

18 files changed

+878
-736
lines changed

nncf/common/tensor_statistics/statistics.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from abc import ABC
1313
from abc import abstractmethod
1414
from collections import Counter
15-
from typing import Any, Dict, TypeVar, cast
15+
from typing import Any, Dict, TypeVar
1616

1717
from nncf.tensor import Tensor
1818
from nncf.tensor import functions as fns
@@ -27,7 +27,7 @@ class TensorStatistic(ABC):
2727

2828
@staticmethod
2929
def tensor_eq(tensor1: Tensor, tensor2: Tensor, rtol: float = 1e-6) -> bool:
30-
return cast(bool, fns.allclose(tensor1, tensor2))
30+
return fns.allclose(tensor1, tensor2, rtol=rtol)
3131

3232
@abstractmethod
3333
def __eq__(self, other: Any) -> bool:

nncf/tensor/README.md

+8-18
Original file line numberDiff line numberDiff line change
@@ -123,45 +123,35 @@ tensor_a[0:2] # Tensor(array([[1],[2]]))
123123
2. Add function to functions module
124124

125125
```python
126-
@functools.singledispatch
127-
def foo(a: TTensor, arg1: Type) -> TTensor:
126+
@tensor_dispatcher
127+
def foo(a: Tensor, arg1: Type) -> Tensor:
128128
"""
129129
__description__
130130
131131
:param a: The input tensor.
132132
:param arg1: __description__
133133
:return: __description__
134134
"""
135-
if isinstance(a, tensor.Tensor):
136-
return tensor.Tensor(foo(a.data, axis))
137-
return NotImplemented(f"Function `foo` is not implemented for {type(a)}")
138135
```
139136

140-
**NOTE** For the case when the first argument has type `List[Tensor]`, use the `_dispatch_list` function. This function dispatches function by first element in the first argument.
141-
142-
```python
143-
@functools.singledispatch
144-
def foo(x: List[Tensor], axis: int = 0) -> Tensor:
145-
if isinstance(x, List):
146-
unwrapped_x = [i.data for i in x]
147-
return Tensor(_dispatch_list(foo, unwrapped_x, axis=axis))
148-
raise NotImplementedError(f"Function `foo` is not implemented for {type(x)}")
149-
```
137+
**NOTE** The wrapping of the return value depends on the return type annotation of the function.
138+
If return type collect `Tensor` than return value will be wrapped in `Tensor` class according annotation,
139+
otherwise return value will be returned as is.
150140

151141
3. Add backend specific implementation of method to corresponding module:
152142

153143
- `functions/numpy_*.py`
154144

155145
```python
156-
@_register_numpy_types(fns.foo)
157-
def _(a: TType, arg1: Type) -> np.ndarray:
146+
@fns.foo.register
147+
def _(a: T_NUMPY_ARRAY, arg1: Type) -> T_NUMPY_ARRAY:
158148
return np.foo(a, arg1)
159149
```
160150

161151
- `functions/torch_*.py`
162152

163153
```python
164-
@fns.foo.register(torch.Tensor)
154+
@fns.foo.register
165155
def _(a: torch.Tensor, arg1: Type) -> torch.Tensor:
166156
return torch.foo(a, arg1)
167157
```

nncf/tensor/definitions.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
from dataclasses import dataclass
1313
from enum import Enum
1414
from enum import auto
15+
from typing import Optional, Tuple, Union
16+
17+
T_SHAPE_ARRAY = Tuple[int, ...]
18+
T_SHAPE = Union[int, T_SHAPE_ARRAY]
19+
T_AXIS = Optional[T_SHAPE]
20+
T_NUMBER = Union[int, float, bool]
1521

1622

1723
class TensorBackend(Enum):
@@ -40,7 +46,7 @@ class TensorDataType(Enum):
4046
uint4 = auto()
4147
int4 = auto()
4248

43-
def is_float(self):
49+
def is_float(self) -> bool:
4450
"""
4551
:return: True if the tensor data type is a floating-point type, else False.
4652
"""

nncf/tensor/functions/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
from nncf.tensor.functions.numeric import zeros_like as zeros_like
6868

6969

70-
def _initialize_backends():
70+
def _initialize_backends() -> None:
7171
import contextlib
7272

7373
import nncf.tensor.functions.numpy_io

0 commit comments

Comments
 (0)