
Commit 99deedf

fatcat-z authored and pytorchmergebot committed
[ONNX] Describe memory usage of TorchDynamo-based exporter. (pytorch#139388)
Add new documentation to show one memory usage benefit brought by the TorchDynamo-based ONNX exporter. Also add a unit test to make sure the TorchDynamo-based ONNX exporter works well under FakeTensorMode.

Pull Request resolved: pytorch#139388
Approved by: https://github.com/xadupre
1 parent d603401 commit 99deedf

4 files changed: +119 −0 lines

docs/source/onnx_dynamo.rst (+8 lines)

@@ -20,6 +20,9 @@ The resulting FX Graph is then polished before it is finally translated into an
 The main advantage of this approach is that the `FX graph <https://pytorch.org/docs/stable/fx.html>`_ is captured using
 bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.

+In addition, during the export process, memory usage is significantly reduced compared to the TorchScript-enabled exporter.
+See the :doc:`documentation <onnx_dynamo_memory_usage>` for more information.
+
 The exporter is designed to be modular and extensible. It is composed of the following components:

 - **ONNX Exporter**: :class:`Exporter` main class that orchestrates the export process.
@@ -149,6 +152,11 @@ The main advantages are:

    generated/onnx_dynamo_diagnostics_rules/*

+.. toctree::
+   :hidden:
+
+   onnx_dynamo_memory_usage
+
 API Reference
 -------------

docs/source/onnx_dynamo_memory_usage.rst (new file, +111 lines)

Understanding TorchDynamo-based ONNX Exporter Memory Usage
==========================================================

The previous TorchScript-based ONNX exporter would execute the model once to trace its execution, which could cause it to run out of
memory on your GPU if the model's memory requirements exceeded the available GPU memory. This issue has been addressed with the new
TorchDynamo-based ONNX exporter.

The TorchDynamo-based ONNX exporter leverages `FakeTensorMode <https://pytorch.org/docs/stable/torch.compiler_fake_tensor.html>`_ to
avoid performing actual tensor computations during the export process. This approach results in significantly lower memory usage
compared to the TorchScript-based ONNX exporter.
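As a minimal illustration of the mechanism (a sketch, not part of this commit): tensors created under ``FakeTensorMode`` carry shapes and dtypes but no real storage, so operators propagate metadata without allocating buffers or computing values.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode, factory functions and operators produce FakeTensors:
# shape and dtype are propagated, but no real buffers are allocated and no
# arithmetic is actually performed.
with FakeTensorMode():
    x = torch.randn(1024, 1024)  # no real 4 MiB buffer behind this tensor
    y = x @ x                    # metadata-only "matmul"

print(type(y).__name__)  # FakeTensor
print(y.shape)           # torch.Size([1024, 1024])
```

This is why an export that runs entirely under fake tensors can stay far below the model's real memory footprint.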
Below is an example demonstrating the memory usage difference between the TorchScript-based and TorchDynamo-based ONNX exporters.
In this example, we use the HighResNet model from MONAI. Before proceeding, please install it from PyPI:

.. code-block:: bash

    pip install monai

PyTorch offers a tool for capturing and visualizing memory usage traces. We will use this tool to record the memory usage of the two
exporters during the export process and compare the results. You can find more details about this tool in
`Understanding CUDA Memory Usage <https://pytorch.org/docs/stable/torch_cuda_memory.html>`_.
TorchScript-based exporter
==========================

The code below can be run to generate a snapshot file that records the state of allocated CUDA memory during the export process.

.. code-block:: python

    import torch
    from monai.networks.nets import HighResNet
    from torch.onnx.utils import export

    # Start recording the CUDA memory allocation history.
    torch.cuda.memory._record_memory_history()

    model = HighResNet(
        spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
    ).eval()
    model = model.to("cuda")
    data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

    with torch.no_grad():
        export(
            model,
            data,
            "torchscript_exporter_highresnet.onnx",
        )

    snapshot_name = "torchscript_exporter_example.pickle"
    print(f"generate {snapshot_name}")
    torch.cuda.memory._dump_snapshot(snapshot_name)
    print("Export is done.")
Open `pytorch.org/memory_viz <https://pytorch.org/memory_viz>`_ and drag/drop the generated pickled snapshot file into the visualizer.
The memory usage is shown below:

.. image:: _static/img/onnx/torch_script_exporter_memory_usage.png

From this figure, we can see that peak memory usage is above 2.8GB.
TorchDynamo-based exporter
==========================

The code below can be run to generate a snapshot file that records the state of allocated CUDA memory during the export process.

.. code-block:: python

    import torch
    from monai.networks.nets import HighResNet

    # Start recording the CUDA memory allocation history.
    torch.cuda.memory._record_memory_history()

    model = HighResNet(
        spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
    ).eval()
    model = model.to("cuda")
    data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

    with torch.no_grad():
        onnx_program = torch.onnx.export(
            model,
            data,
            "test_faketensor.onnx",
            dynamo=True,
        )

    snapshot_name = "torchdynamo_exporter_example.pickle"
    print(f"generate {snapshot_name}")
    torch.cuda.memory._dump_snapshot(snapshot_name)
    print("Export is done.")
Open `pytorch.org/memory_viz <https://pytorch.org/memory_viz>`_ and drag/drop the generated pickled snapshot file into the visualizer.
The memory usage is shown below:

.. image:: _static/img/onnx/torch_dynamo_exporter_memory_usage.png

From this figure, we can see that peak memory usage is only around 45MB. Compared to the peak memory usage of the TorchScript-based
exporter, this is a reduction of about 98%.
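For a numeric cross-check without the visualizer, PyTorch's peak-memory counters can report the peak allocation of an export call directly. The helper below is a sketch, not part of this commit; it assumes a CUDA device is available, and the commented usage reuses the ``model`` and ``data`` names from the snippets above.

```python
import torch

def peak_cuda_mib(run) -> float:
    """Run `run()` and return the peak CUDA memory it allocated, in MiB."""
    torch.cuda.reset_peak_memory_stats()
    run()
    return torch.cuda.max_memory_allocated() / 2**20

# Hypothetical usage with the `model` and `data` defined in the snippets above
# (requires a CUDA device and monai):
# peak = peak_cuda_mib(
#     lambda: torch.onnx.export(model, data, "model.onnx", dynamo=True)
# )
# print(f"peak CUDA memory: {peak:.1f} MiB")
```

`reset_peak_memory_stats` clears the high-water mark so the reported peak reflects only the wrapped call, not earlier allocations.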
