Skip to content

Commit b0897b3

Browse files
[PT] detach tensor in get_const_data (#3199)
### Changes Return detached tensor from `get_const_data` function ### Reason for changes Memory leak
1 parent 34795fd commit b0897b3

File tree

3 files changed

+20
-11
lines changed

3 files changed

+20
-11
lines changed

nncf/quantization/algorithms/weight_compression/torch_backend.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
3939
from nncf.torch.graph.transformations.commands import PTTargetPoint
4040
from nncf.torch.model_graph_manager import find_const_node_in_constant_subgraph
41+
from nncf.torch.model_graph_manager import get_const_data
4142
from nncf.torch.model_graph_manager import get_const_node
4243
from nncf.torch.model_graph_manager import get_module_by_name
4344
from nncf.torch.model_graph_manager import split_const_name
@@ -173,10 +174,8 @@ def get_weight(
173174
) -> Tensor:
174175
weight_node = get_const_node(node_with_weight, weight_port_id, graph)
175176
weight_name = weight_node.layer_attributes.name
176-
module_name, weight_attr_name = split_const_name(weight_name)
177-
module = get_module_by_name(module_name, model)
178-
weight = getattr(module, weight_attr_name)
179-
if weight is None or not isinstance(weight, torch.nn.Parameter):
177+
weight = get_const_data(weight_node, model)
178+
if weight is None:
180179
raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.")
181180

182181
return Tensor(weight)
@@ -222,10 +221,8 @@ def transform_model(
222221

223222
weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph)
224223
weight_name = weight_node.layer_attributes.name
225-
module_name, weight_attr_name = split_const_name(weight_name)
226-
module = get_module_by_name(module_name, model)
227-
weight = getattr(module, weight_attr_name)
228-
if weight is None or not isinstance(weight, torch.nn.Parameter):
224+
weight = get_const_data(weight_node, model)
225+
if weight is None:
229226
raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.")
230227

231228
# calculates compressed weights and decompression parameters
@@ -264,7 +261,14 @@ def transform_model(
264261
packed_tensor = decompressor.pack_weight(compressed_weight.tensor.data)
265262

266263
# sets compressed tensor
264+
# TODO:(AlexanderDokuchaev): update set_const_data
267265
compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
266+
module_name, weight_attr_name = split_const_name(weight_name)
267+
module = get_module_by_name(module_name, model)
268+
weight = getattr(module, weight_attr_name)
269+
if not isinstance(weight, torch.nn.Parameter):
270+
raise nncf.InternalError(f"Weight is not a torch.nn.Parameter in the model by name {weight_name}.")
271+
268272
setattr(module, weight_attr_name, compressed_parameter)
269273

270274
consumer_nodes = graph.get_next_nodes(weight_node)

nncf/torch/model_graph_manager.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Mod
118118

119119
def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor:
120120
"""
121-
Retrieves a constant tensor associated with a given node.
121+
Retrieves a detached constant tensor associated with a given node.
122122
123123
:param const_node: The node associated with const data.
124124
:param model: The NNCFNetwork object.
@@ -129,8 +129,8 @@ def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor:
129129
module = get_module_by_name(module_name, model)
130130
data = getattr(module, const_attr_name)
131131
if isinstance(data, torch.nn.Parameter):
132-
return data.data
133-
return data
132+
return data.data.detach()
133+
return data.detach()
134134

135135

136136
def get_const_data_on_port(node: NNCFNode, port_id: int, model: NNCFNetwork) -> torch.Tensor:

tests/torch/test_model_graph_manager.py

+5
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,15 @@ def test_get_set_const_data():
239239
graph = model.nncf.get_graph()
240240
const_node = graph.get_node_by_name("conv.bias")
241241

242+
assert model.conv.bias.requires_grad
243+
242244
data = get_const_data(const_node, model)
245+
assert not data.requires_grad
243246
assert torch.all(model.conv.bias.data == data)
247+
244248
set_const_data(torch.ones_like(data), const_node, model)
245249
assert torch.all(model.conv.bias.data == torch.ones_like(data))
250+
assert model.conv.bias.requires_grad
246251

247252

248253
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)