Commit f71e368

arunppsg authored and pytorchmergebot committed
UFMT formatting on test/autograd test/ao test/cpp test/backends (pytorch#123369)
Partially addresses pytorch#123062. Ran lintrunner on:

- test/_test_bazel.py
- test/ao
- test/autograd
- test/backends
- test/benchmark_utils
- test/bottleneck_test
- test/conftest.py
- test/cpp

Pull Request resolved: pytorch#123369
Approved by: https://github.com/huydhn
1 parent de7edee commit f71e368

23 files changed: +1914 -1035 lines
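
For context, UFMT is lintrunner's formatting linter, which runs usort (import sorting) followed by black (code formatting), so every hunk below is tool-generated. A sketch of the kind of invocation that produces such a change, assuming lintrunner's documented -a (apply patches) and --take (select linters) flags, with the paths taken from the commit message:

    lintrunner -a --take UFMT test/_test_bazel.py test/ao test/autograd \
        test/backends test/benchmark_utils test/bottleneck_test \
        test/conftest.py test/cpp

Without -a, lintrunner only reports the formatting diff instead of applying it.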

.lintrunner.toml (-27)

@@ -1014,33 +1014,6 @@ exclude_patterns = [
     'test/_nvfuser/test_dynamo.py',
     'test/_nvfuser/test_python_frontend.py',
     'test/_nvfuser/test_torchscript.py',
-    'test/_test_bazel.py',
-    'test/ao/sparsity/test_activation_sparsifier.py',
-    'test/ao/sparsity/test_composability.py',
-    'test/ao/sparsity/test_data_scheduler.py',
-    'test/ao/sparsity/test_data_sparsifier.py',
-    'test/ao/sparsity/test_kernels.py',
-    'test/ao/sparsity/test_parametrization.py',
-    'test/ao/sparsity/test_qlinear_packed_params.py',
-    'test/ao/sparsity/test_scheduler.py',
-    'test/ao/sparsity/test_sparsifier.py',
-    'test/ao/sparsity/test_sparsity_utils.py',
-    'test/ao/sparsity/test_structured_sparsifier.py',
-    'test/autograd/test_complex.py',
-    'test/autograd/test_fallback.py',
-    'test/autograd/test_functional.py',
-    'test/backends/xeon/test_launch.py',
-    'test/benchmark_utils/test_benchmark_utils.py',
-    'test/bottleneck_test/test.py',
-    'test/bottleneck_test/test_args.py',
-    'test/bottleneck_test/test_cuda.py',
-    'test/conftest.py',
-    'test/cpp/__init__.py',
-    'test/cpp/aot_inductor/test.py',
-    'test/cpp/api/init_baseline.py',
-    'test/cpp/api/optim_baseline.py',
-    'test/cpp/jit/__init__.py',
-    'test/cpp/jit/tests_setup.py',
     'test/cpp_api_parity/__init__.py',
     'test/cpp_api_parity/functional_impl_check.py',
     'test/cpp_api_parity/module_impl_check.py',
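
Note that UFMT coverage in .lintrunner.toml works by omission: exclude_patterns lists the files that are not yet formatted, so deleting the 27 entries above is what turns UFMT enforcement on for them; the include pattern itself is unchanged. To check a single newly covered file locally (same assumed flags as the sketch above):

    lintrunner --take UFMT test/autograd/test_complex.py

A clean run confirms the file already matches the formatter's output.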

test/_test_bazel.py (+5 -2)

@@ -11,11 +11,14 @@
 
 import torch
 
+
 def test_sum() -> None:
-    assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all()
+    assert torch.eq(
+        torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])
+    ).all()
 
-def test_simple_compile_eager() -> None:
 
+def test_simple_compile_eager() -> None:
     def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         a = torch.sin(x)
         b = torch.cos(y)
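
The hunk above is characteristic black output: the one-line assert exceeded black's default 88-character line length, so it is wrapped at the call parentheses, and blank-line normalization enforces two blank lines between top-level functions. To preview or apply the same rewrite without going through lintrunner, the ufmt CLI can be used directly, assuming the ufmt package (usort + black) is installed:

    ufmt diff test/_test_bazel.py     # print the would-be changes
    ufmt format test/_test_bazel.py   # rewrite the file in place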

test/ao/sparsity/test_activation_sparsifier.py (+87 -58)

@@ -1,16 +1,21 @@
 # Owner(s): ["module: unknown"]
 
 import copy
-from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
 import logging
+from typing import List
+
 import torch
-from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import ActivationSparsifier
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import (
+    ActivationSparsifier,
+)
 from torch.ao.pruning.sparsifier.utils import module_to_fqn
-from typing import List
+from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
 
-logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)
 
 
 class Model(nn.Module):
@@ -45,7 +50,7 @@ def _check_constructor(self, activation_sparsifier, model, defaults, sparse_conf
         in the activation sparsifier
         """
         sparsifier_defaults = activation_sparsifier.defaults
-        combined_defaults = {**defaults, 'sparse_config': sparse_config}
+        combined_defaults = {**defaults, "sparse_config": sparse_config}
 
         # more keys are populated in activation sparsifier (eventhough they may be None)
         assert len(combined_defaults) <= len(activation_sparsifier.defaults)
@@ -54,7 +59,9 @@ def _check_constructor(self, activation_sparsifier, model, defaults, sparse_conf
             # all the keys in combined_defaults should be present in sparsifier defaults
             assert config == combined_defaults.get(key, None)
 
-    def _check_register_layer(self, activation_sparsifier, defaults, sparse_config, layer_args_list):
+    def _check_register_layer(
+        self, activation_sparsifier, defaults, sparse_config, layer_args_list
+    ):
         """Checks if layers in the model are correctly mapped to it's arguments.
 
         Args:
@@ -82,14 +89,14 @@ def _check_register_layer(self, activation_sparsifier, defaults, sparse_config,
             sparse_config_actual = copy.deepcopy(sparse_config)
             sparse_config_actual.update(sparse_config_layer)
 
-            name = module_to_fqn(activation_sparsifier.model, layer_arg['layer'])
+            name = module_to_fqn(activation_sparsifier.model, layer_arg["layer"])
 
-            assert data_groups[name]['sparse_config'] == sparse_config_actual
+            assert data_groups[name]["sparse_config"] == sparse_config_actual
 
             # assert the rest
             other_config_actual = copy.deepcopy(defaults)
             other_config_actual.update(layer_arg)
-            other_config_actual.pop('layer')
+            other_config_actual.pop("layer")
 
             for key, value in other_config_actual.items():
                 assert key in data_groups[name]
@@ -119,13 +126,15 @@ def _check_pre_forward_hook(self, activation_sparsifier, data_list):
         data_agg_actual = data_list[0]
         model = activation_sparsifier.model
         layer_name = module_to_fqn(model, model.conv1)
-        agg_fn = activation_sparsifier.data_groups[layer_name]['aggregate_fn']
+        agg_fn = activation_sparsifier.data_groups[layer_name]["aggregate_fn"]
 
         for i in range(1, len(data_list)):
             data_agg_actual = agg_fn(data_agg_actual, data_list[i])
 
-        assert 'data' in activation_sparsifier.data_groups[layer_name]
-        assert torch.all(activation_sparsifier.data_groups[layer_name]['data'] == data_agg_actual)
+        assert "data" in activation_sparsifier.data_groups[layer_name]
+        assert torch.all(
+            activation_sparsifier.data_groups[layer_name]["data"] == data_agg_actual
+        )
 
         return data_agg_actual
 
@@ -144,20 +153,19 @@ def _check_step(self, activation_sparsifier, data_agg_actual):
         layer_name = module_to_fqn(model, model.conv1)
         assert layer_name is not None
 
-        reduce_fn = activation_sparsifier.data_groups[layer_name]['reduce_fn']
+        reduce_fn = activation_sparsifier.data_groups[layer_name]["reduce_fn"]
 
         data_reduce_actual = reduce_fn(data_agg_actual)
-        mask_fn = activation_sparsifier.data_groups[layer_name]['mask_fn']
-        sparse_config = activation_sparsifier.data_groups[layer_name]['sparse_config']
+        mask_fn = activation_sparsifier.data_groups[layer_name]["mask_fn"]
+        sparse_config = activation_sparsifier.data_groups[layer_name]["sparse_config"]
         mask_actual = mask_fn(data_reduce_actual, **sparse_config)
 
         mask_model = activation_sparsifier.get_mask(layer_name)
 
         assert torch.all(mask_model == mask_actual)
 
         for config in activation_sparsifier.data_groups.values():
-            assert 'data' not in config
-
+            assert "data" not in config
 
     def _check_squash_mask(self, activation_sparsifier, data):
         """Makes sure that squash_mask() works as usual. Specifically, checks
@@ -172,32 +180,41 @@ def _check_squash_mask(self, activation_sparsifier, data):
             data (torch tensor)
                 dummy batched data
         """
+
         # create a forward hook for checking output == layer(input * mask)
         def check_output(name):
            mask = activation_sparsifier.get_mask(name)
-            features = activation_sparsifier.data_groups[name].get('features')
-            feature_dim = activation_sparsifier.data_groups[name].get('feature_dim')
+            features = activation_sparsifier.data_groups[name].get("features")
+            feature_dim = activation_sparsifier.data_groups[name].get("feature_dim")
 
             def hook(module, input, output):
                 input_data = input[0]
                 if features is None:
                     assert torch.all(mask * input_data == output)
                 else:
                     for feature_idx in range(0, len(features)):
-                        feature = torch.Tensor([features[feature_idx]], device=input_data.device).long()
-                        inp_data_feature = torch.index_select(input_data, feature_dim, feature)
-                        out_data_feature = torch.index_select(output, feature_dim, feature)
+                        feature = torch.Tensor(
+                            [features[feature_idx]], device=input_data.device
+                        ).long()
+                        inp_data_feature = torch.index_select(
+                            input_data, feature_dim, feature
+                        )
+                        out_data_feature = torch.index_select(
+                            output, feature_dim, feature
+                        )
+
+                        assert torch.all(
+                            mask[feature_idx] * inp_data_feature == out_data_feature
+                        )
 
-                        assert torch.all(mask[feature_idx] * inp_data_feature == out_data_feature)
             return hook
 
         for name, config in activation_sparsifier.data_groups.items():
-            if 'identity' in name:
-                config['layer'].register_forward_hook(check_output(name))
+            if "identity" in name:
+                config["layer"].register_forward_hook(check_output(name))
 
         activation_sparsifier.model(data)
 
-
     def _check_state_dict(self, sparsifier1):
         """Checks if loading and restoring of state_dict() works as expected.
         Basically, dumps the state of the sparsifier and loads it in the other sparsifier
@@ -222,8 +239,8 @@ def _check_state_dict(self, sparsifier1):
 
         for name, state in sparsifier2.state.items():
             assert name in sparsifier1.state
-            mask1 = sparsifier1.state[name]['mask']
-            mask2 = state['mask']
+            mask1 = sparsifier1.state[name]["mask"]
+            mask2 = state["mask"]
 
             if mask1 is None:
                 assert mask2 is None
@@ -237,8 +254,8 @@ def _check_state_dict(self, sparsifier1):
                 assert torch.all(mask1 == mask2)
 
         # make sure that the state dict is stored as torch sparse
-        for state in state_dict['state'].values():
-            mask = state['mask']
+        for state in state_dict["state"].values():
+            mask = state["mask"]
             if mask is not None:
                 if isinstance(mask, List):
                     for idx in range(len(mask)):
@@ -252,8 +269,16 @@ def _check_state_dict(self, sparsifier1):
             assert layer_name in dg2
 
             # exclude hook and layer
-            config1 = {key: value for key, value in config.items() if key not in ['hook', 'layer']}
-            config2 = {key: value for key, value in dg2[layer_name].items() if key not in ['hook', 'layer']}
+            config1 = {
+                key: value
+                for key, value in config.items()
+                if key not in ["hook", "layer"]
+            }
+            config2 = {
+                key: value
+                for key, value in dg2[layer_name].items()
+                if key not in ["hook", "layer"]
+            }
 
             assert config1 == config2
 
@@ -263,6 +288,7 @@ def test_activation_sparsifier(self):
         till squash_mask().
         The idea is to check that everything works as expected while in the workflow.
         """
+
        # defining aggregate, reduce and mask functions
        def agg_fn(x, y):
            return x + y
@@ -287,14 +313,9 @@ def _vanilla_norm_sparsifier(data, sparsity_level):
 
         # Creating default function and sparse configs
         # default sparse_config
-        sparse_config = {
-            'sparsity_level': 0.5
-        }
+        sparse_config = {"sparsity_level": 0.5}
 
-        defaults = {
-            'aggregate_fn': agg_fn,
-            'reduce_fn': reduce_fn
-        }
+        defaults = {"aggregate_fn": agg_fn, "reduce_fn": reduce_fn}
 
         # simulate the workflow
         # STEP 1: make data and activation sparsifier object
@@ -306,43 +327,51 @@ def _vanilla_norm_sparsifier(data, sparsity_level):
 
         # STEP 2: Register some layers
         register_layer1_args = {
-            'layer': model.conv1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.conv1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer1 = {'sparsity_level': 0.3}
+        sparse_config_layer1 = {"sparsity_level": 0.3}
 
         register_layer2_args = {
-            'layer': model.linear1,
-            'features': [0, 10, 234],
-            'feature_dim': 1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.linear1,
+            "features": [0, 10, 234],
+            "feature_dim": 1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer2 = {'sparsity_level': 0.1}
+        sparse_config_layer2 = {"sparsity_level": 0.1}
 
         register_layer3_args = {
-            'layer': model.identity1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.identity1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer3 = {'sparsity_level': 0.3}
+        sparse_config_layer3 = {"sparsity_level": 0.3}
 
         register_layer4_args = {
-            'layer': model.identity2,
-            'features': [0, 10, 20],
-            'feature_dim': 1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.identity2,
+            "features": [0, 10, 20],
+            "feature_dim": 1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer4 = {'sparsity_level': 0.1}
+        sparse_config_layer4 = {"sparsity_level": 0.1}
 
-        layer_args_list = [(register_layer1_args, sparse_config_layer1), (register_layer2_args, sparse_config_layer2)]
-        layer_args_list += [(register_layer3_args, sparse_config_layer3), (register_layer4_args, sparse_config_layer4)]
+        layer_args_list = [
+            (register_layer1_args, sparse_config_layer1),
+            (register_layer2_args, sparse_config_layer2),
+        ]
+        layer_args_list += [
+            (register_layer3_args, sparse_config_layer3),
+            (register_layer4_args, sparse_config_layer4),
+        ]
 
         # Registering..
         for layer_args in layer_args_list:
             layer_arg, sparse_config_layer = layer_args
             activation_sparsifier.register_layer(**layer_arg, **sparse_config_layer)
 
         # check if things are registered correctly
-        self._check_register_layer(activation_sparsifier, defaults, sparse_config, layer_args_list)
+        self._check_register_layer(
+            activation_sparsifier, defaults, sparse_config, layer_args_list
+        )
 
         # check state_dict after registering and before model forward
         self._check_state_dict(activation_sparsifier)
