@@ -61,10 +61,10 @@ def test_extract_model(model_cls, input_node_name, output_node_name):
     model = wrap_model(model_cls().eval(), example_input=example_input, trace_parameters=True)
     extracted_module = extract_model(model, [input_node_name], [output_node_name])
-    with torch.no_grad():
-        ret1 = model(example_input)
-        ret2 = extracted_module(example_input)
-        assert torch.any(torch.isclose(ret1, ret2))
+    ret1 = model(example_input)
+    ret2 = extracted_module(example_input)
+    assert not ret2.grad_fn
+    assert torch.any(torch.isclose(ret1, ret2))


 @pytest.mark.parametrize(
@@ -122,10 +122,11 @@ def test_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_name):
     q_model = transformer.transform(layout)

     extracted_module = extract_model(model, [input_node_name], [output_node_name])
-    with torch.no_grad():
-        ret1 = q_model(example_input)
-        ret2 = extracted_module(example_input)
-        assert torch.all(torch.isclose(ret1, ret2))
+
+    ret1 = q_model(example_input)
+    ret2 = extracted_module(example_input)
+    assert torch.all(torch.isclose(ret1, ret2))
+    assert not ret2.grad_fn

     extracted_fn = extracted_module
     if isinstance(extracted_fn, nn.Sequential):