Commit 99f0c44

[PT] disable grad for extracted module (#3266)
### Changes

Disable gradient for extracted modules.

### Reason for changes

Possible memory leaks.

### Tests

manual/job/post_training_quantization/606

1 parent: c24bf74
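For context on the failure mode: when a module's parameters still have `requires_grad=True`, every forward pass records an autograd graph on the output, and that graph keeps intermediate activations alive for as long as the output is referenced. Below is a minimal PyTorch sketch of this effect; it is illustrative only, not NNCF code, and the `BatchNorm2d` module and tensor shapes are made up:

```python
import torch
import torch.nn as nn

# Illustrative only (not NNCF code): parameters with requires_grad=True
# make even eval-mode forwards record an autograd graph on the output.
bn = nn.BatchNorm2d(3).eval()
x = torch.randn(1, 3, 8, 8)

out = bn(x)
print(out.grad_fn is not None)  # True: a graph (and its activations) is retained

# Freezing the parameters stops graph construction entirely.
for p in bn.parameters():
    p.requires_grad = False

out = bn(x)
print(out.grad_fn is None)  # True: no graph is recorded for this output
```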

File tree: 2 files changed, +11 -8 lines

nncf/torch/extractor.py (+2)

@@ -153,6 +153,8 @@ def extract_bn(node: NNCFNode, model: NNCFNetwork) -> Optional[Union[nn.BatchNor
     for name, _ in chain(extracted_bn.named_parameters(), extracted_bn.named_buffers()):
         setattr(extracted_bn, name, deepcopy(getattr(bn_module, name)))
     extracted_bn.eval()
+    extracted_bn.weight.requires_grad = False
+    extracted_bn.bias.requires_grad = False
     return extracted_bn

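The fix sets `requires_grad = False` explicitly on the two learnable batch-norm parameters, `weight` and `bias`. A more general pattern with the equivalent effect, sketched here under the assumption that you want to freeze every parameter of an arbitrary extracted module, is `nn.Module.requires_grad_(False)`:

```python
import torch.nn as nn

# Sketch of an equivalent, more general freeze: requires_grad_(False)
# flips requires_grad on every parameter the module owns.
extracted = nn.BatchNorm2d(16).eval()  # stand-in for an extracted module
extracted.requires_grad_(False)
assert all(not p.requires_grad for p in extracted.parameters())
```

The committed change keeps the explicit per-parameter form, which touches only the two parameters a `BatchNorm` module is known to have.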
tests/torch/test_extractor.py (+9 -8)

@@ -61,10 +61,10 @@ def test_extract_model(model_cls, input_node_name, output_node_name):

     model = wrap_model(model_cls().eval(), example_input=example_input, trace_parameters=True)
     extracted_module = extract_model(model, [input_node_name], [output_node_name])
-    with torch.no_grad():
-        ret1 = model(example_input)
-        ret2 = extracted_module(example_input)
-        assert torch.any(torch.isclose(ret1, ret2))
+    ret1 = model(example_input)
+    ret2 = extracted_module(example_input)
+    assert not ret2.grad_fn
+    assert torch.any(torch.isclose(ret1, ret2))


 @pytest.mark.parametrize(
@@ -122,10 +122,11 @@ def test_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_
     q_model = transformer.transform(layout)

     extracted_module = extract_model(model, [input_node_name], [output_node_name])
-    with torch.no_grad():
-        ret1 = q_model(example_input)
-        ret2 = extracted_module(example_input)
-        assert torch.all(torch.isclose(ret1, ret2))
+
+    ret1 = q_model(example_input)
+    ret2 = extracted_module(example_input)
+    assert torch.all(torch.isclose(ret1, ret2))
+    assert not ret2.grad_fn

     extracted_fn = extracted_module
     if isinstance(extracted_fn, nn.Sequential):
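The updated tests drop the `torch.no_grad()` context and instead assert that the extracted module's output carries no `grad_fn`, i.e. that no autograd graph was built during the forward pass. A self-contained sketch of what that assertion checks, using a hypothetical frozen `nn.Linear` as a stand-in rather than the test's actual models:

```python
import torch
import torch.nn as nn

# A tensor's grad_fn is None when no autograd graph was recorded for it.
# This is what `assert not ret2.grad_fn` verifies for the extracted module.
frozen = nn.Linear(4, 4)          # hypothetical stand-in for an extracted module
frozen.requires_grad_(False)
out = frozen(torch.randn(2, 4))   # the input does not require grad either
assert out.grad_fn is None        # no graph, so nothing is retained for backward
```

Unlike wrapping the call in `torch.no_grad()`, this formulation also fails if a future change reintroduces gradient-tracking parameters into extracted modules.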
