Skip to content

Commit bd266dc

Browse files
authored
[PT FE] Support torch==2.6.0 (openvinotoolkit#29196)
### Details: - *Support `torch==2.6.0`* ### Tickets: - *CVS-162009* --------- Signed-off-by: Maxim Vafin <maxim.vafin@intel.com>
1 parent 3444a4a commit bd266dc

File tree

7 files changed

+46
-34
lines changed

7 files changed

+46
-34
lines changed

.github/workflows/job_pytorch_layer_tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ env:
3030
jobs:
3131
PyTorch_Layer_Tests:
3232
name: PyTorch Layer Tests
33-
timeout-minutes: 40
33+
timeout-minutes: 50
3434
runs-on: ${{ inputs.runner }}
3535
container: ${{ fromJSON(inputs.container) }}
3636
defaults:

src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py

+28
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ class TorchFXPythonDecoder (BaseFXDecoder):
166166
Decoder for PyTorch FX GraphModule and Node objects to OpenVINO IR.
167167
"""
168168

169+
_decomp_table = None
170+
169171
def __init__(self, pt_module, fx_gm=None, nodes=None,
170172
mark_node_callback=None, input_shapes=[], input_types=[], dynamic_shapes=False):
171173
super().__init__(mark_node_callback)
@@ -230,6 +232,32 @@ def __init__(self, pt_module, fx_gm=None, nodes=None,
230232
self.input_types.append(
231233
BaseFXDecoder.get_type_for_value(arg))
232234

235+
@classmethod
236+
def from_exported_program(cls, exported_program: torch.export.ExportedProgram) -> 'TorchFXPythonDecoder':
237+
"""
238+
Create a TorchFXPythonDecoder instance from an exported PyTorch program.
239+
"""
240+
from packaging import version
241+
if version.parse(torch.__version__) >= version.parse("2.6"):
242+
if cls._decomp_table is None:
243+
from torch.export.decomp_utils import CustomDecompTable
244+
from openvino.frontend.pytorch.torchdynamo.decompositions import ops_to_not_decompose
245+
cls._decomp_table = CustomDecompTable()
246+
for op in ops_to_not_decompose():
247+
try:
248+
cls._decomp_table.pop(op)
249+
except KeyError as e:
250+
logging.warning("Operation %s not found in decomp table", op, exc_info=e)
251+
exported_program = exported_program.run_decompositions(cls._decomp_table)
252+
elif version.parse(torch.__version__) >= version.parse("2.2"):
253+
from torch._decomp import get_decompositions
254+
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
255+
decomp = get_decompositions(get_export_decomposition_list())
256+
exported_program = exported_program.run_decompositions(decomp_table=decomp)
257+
gm = exported_program.module()
258+
logger.debug(gm.code)
259+
return cls(gm, dynamic_shapes=True)
260+
233261
@staticmethod
234262
def get_found_shape(value) -> str:
235263
# If input is a tensor, read the shape from meta data

src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py

+8
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,11 @@ def get_export_decomposition_list():
297297
except ImportError:
298298
pass
299299
return decomp
300+
301+
302+
def ops_to_not_decompose():
303+
# List of operations that shouldn't be decomposed
304+
return [
305+
torch.ops.aten.col2im.default,
306+
torch.ops.aten.upsample_nearest2d.default,
307+
]

tests/layer_tests/pytorch_tests/test_arange.py

-6
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def forward(self, x, y, z, d):
109109

110110
@pytest.mark.nightly
111111
@pytest.mark.precommit_torch_export
112-
@pytest.mark.precommit_fx_backend
113112
@pytest.mark.parametrize("dtype", [None,
114113
skip_if_export("float32"),
115114
skip_if_export("float64"),
@@ -124,7 +123,6 @@ def test_arange_end_only(self, dtype, end, use_out, ie_device, precision, ir_ver
124123
kwargs_to_prepare_input={"end": end})
125124

126125
@pytest.mark.nightly
127-
@pytest.mark.precommit_fx_backend
128126
@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"])
129127
@pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)])
130128
def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_version):
@@ -133,7 +131,6 @@ def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_vers
133131

134132
@pytest.mark.nightly
135133
@pytest.mark.precommit
136-
@pytest.mark.precommit_fx_backend
137134
@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"])
138135
@pytest.mark.parametrize("start,end,step", [(0, 1, 1), (-2, 1, 1.25), (1, -5, -1), (1, 10, 2), (-1, -5, -2)])
139136
def test_arange_start_end_step(self, dtype, end, start, step, ie_device, precision, ir_version):
@@ -142,7 +139,6 @@ def test_arange_start_end_step(self, dtype, end, start, step, ie_device, precisi
142139

143140
@pytest.mark.nightly
144141
@pytest.mark.precommit_torch_export
145-
@pytest.mark.precommit_fx_backend
146142
@pytest.mark.parametrize("dtype", [skip_check(None),
147143
skip_if_export("float32"),
148144
skip_if_export("float64"),
@@ -156,7 +152,6 @@ def test_arange_end_only_with_prim_dtype(self, dtype, end, ie_device, precision,
156152
kwargs_to_prepare_input={"end": end, "ref_dtype": dtype})
157153

158154
@pytest.mark.nightly
159-
@pytest.mark.precommit_fx_backend
160155
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"])
161156
@pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)])
162157
def test_arange_start_end_with_prim_dtype(self, dtype, end, start, ie_device, precision, ir_version):
@@ -165,7 +160,6 @@ def test_arange_start_end_with_prim_dtype(self, dtype, end, start, ie_device, pr
165160

166161
@pytest.mark.nightly
167162
@pytest.mark.precommit
168-
@pytest.mark.precommit_fx_backend
169163
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"])
170164
@pytest.mark.parametrize("start,end,step", [(0, 1, 1), (-2, 1, 1.25), (1, -5, -1), (1, 10, 2), (-1, -5, -2)])
171165
def test_arange_start_end_step_with_prim_dtype(self, dtype, end, start, step, ie_device, precision, ir_version):

tests/layer_tests/pytorch_tests/test_trilu.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,15 @@ def forward(self, x):
3535

3636
return aten_trilu(pt_op, diagonal), ref_net, f"aten::{op}"
3737

38-
@pytest.mark.parametrize("input_shape", [(5, 5), (6, 4), (4, 6)])
3938
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8", "uint8", "bool"])
4039
@pytest.mark.parametrize("diagonal", [0, 1, 2, -1, -2])
4140
@pytest.mark.parametrize("op", ["triu", "tril"])
4241
@pytest.mark.nightly
4342
@pytest.mark.precommit
4443
@pytest.mark.precommit_fx_backend
45-
def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version):
44+
def test_trilu(self, dtype, diagonal, op, ie_device, precision, ir_version):
4645
self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version,
47-
kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})
46+
kwargs_to_prepare_input={"shape": (4, 6), "dtype": dtype})
4847

4948

5049
class TestTriuTrilTensor(PytorchLayerTest):
@@ -84,13 +83,12 @@ def triu_(self, x):
8483

8584
return aten_trilu(op, diagonal), ref_net, f"aten::{op}"
8685

87-
@pytest.mark.parametrize("input_shape", [(5, 5), (6, 4), (4, 6)])
8886
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8", "uint8", "bool"])
8987
@pytest.mark.parametrize("diagonal", [0, 1, 2, -1, -2])
9088
@pytest.mark.parametrize("op", ["triu", "tril", "triu_", "tril_"])
9189
@pytest.mark.nightly
9290
@pytest.mark.precommit
9391
@pytest.mark.precommit_fx_backend
94-
def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version):
92+
def test_trilu(self, dtype, diagonal, op, ie_device, precision, ir_version):
9593
self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version,
96-
kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})
94+
kwargs_to_prepare_input={"shape": (4, 6), "dtype": dtype})

tests/requirements_pytorch

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
# optimum still requires numpy<2.0.0
44
numpy==1.26.4; python_version < "3.12"
55
numpy==2.1.1; python_version >= "3.12"
6-
torch==2.5.1; platform_system != "Darwin" or platform_machine != "x86_64"
6+
torch==2.6.0; platform_system != "Darwin" or platform_machine != "x86_64"
77
torch==2.2.2; platform_system == "Darwin" and platform_machine == "x86_64"
88
--extra-index-url https://download.pytorch.org/whl/cpu
99

10-
torchvision==0.20.1; platform_system != "Darwin" or platform_machine != "x86_64"
10+
torchvision==0.21.0; platform_system != "Darwin" or platform_machine != "x86_64"
1111
torchvision==0.17.2; platform_system == "Darwin" and platform_machine == "x86_64"
12-
torchaudio==2.5.1; platform_system != "Darwin" or platform_machine != "x86_64"
12+
torchaudio==2.6.0; platform_system != "Darwin" or platform_machine != "x86_64"
1313
torchaudio==2.2.2; platform_system == "Darwin" and platform_machine == "x86_64"
1414
# before updating transformers version, make sure no tests (esp. sdpa2pa) are failing
1515
transformers==4.47.1

tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py

+2-18
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,6 @@ def extract_module_extensions(args):
2121
return {extension.module: extension for extension in extensions if isinstance(extension, ModuleExtension)}
2222

2323

24-
def get_decoder_for_exported_program(model):
25-
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
26-
import torch
27-
28-
from packaging import version
29-
if version.parse(torch.__version__) >= version.parse("2.2"):
30-
from torch._decomp import get_decompositions
31-
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
32-
decomp = get_decompositions(get_export_decomposition_list())
33-
model = model.run_decompositions(decomp_table=decomp)
34-
gm = model.module()
35-
log.debug(gm.code)
36-
decoder = TorchFXPythonDecoder(gm, dynamic_shapes=True)
37-
return decoder
38-
39-
4024
def get_pytorch_decoder(model, example_inputs, args):
4125
try:
4226
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
@@ -65,7 +49,7 @@ def get_pytorch_decoder(model, example_inputs, args):
6549
inputs = prepare_torch_inputs(example_inputs)
6650
if not isinstance(model, (TorchScriptPythonDecoder, TorchFXPythonDecoder)):
6751
if hasattr(torch, "export") and isinstance(model, (torch.export.ExportedProgram)):
68-
decoder = get_decoder_for_exported_program(model)
52+
decoder = TorchFXPythonDecoder.from_exported_program(model)
6953
else:
7054
decoder = TorchScriptPythonDecoder(
7155
model,
@@ -123,7 +107,7 @@ def get_pytorch_decoder_for_model_on_disk(argv, args):
123107
try:
124108
exported_program = torch.export.load(input_model)
125109
if hasattr(torch, "export") and isinstance(exported_program, (torch.export.ExportedProgram)):
126-
argv.input_model = get_decoder_for_exported_program(exported_program)
110+
argv.input_model = TorchFXPythonDecoder.from_exported_program(exported_program)
127111
argv.framework = 'pytorch'
128112
return True
129113
except:

0 commit comments

Comments
 (0)