@@ -1,16 +1,21 @@
 # Owner(s): ["module: unknown"]

 import copy
-from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
 import logging
+from typing import List
+
 import torch
-from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import ActivationSparsifier
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import (
+    ActivationSparsifier,
+)
 from torch.ao.pruning.sparsifier.utils import module_to_fqn
-from typing import List
+from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase

-logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)


 class Model(nn.Module):
@@ -45,7 +50,7 @@ def _check_constructor(self, activation_sparsifier, model, defaults, sparse_config):
         in the activation sparsifier
         """
         sparsifier_defaults = activation_sparsifier.defaults
-        combined_defaults = {**defaults, 'sparse_config': sparse_config}
+        combined_defaults = {**defaults, "sparse_config": sparse_config}

         # more keys are populated in activation sparsifier (even though they may be None)
         assert len(combined_defaults) <= len(activation_sparsifier.defaults)
@@ -54,7 +59,9 @@ def _check_constructor(self, activation_sparsifier, model, defaults, sparse_config):
             # all the keys in combined_defaults should be present in sparsifier defaults
             assert config == combined_defaults.get(key, None)

-    def _check_register_layer(self, activation_sparsifier, defaults, sparse_config, layer_args_list):
+    def _check_register_layer(
+        self, activation_sparsifier, defaults, sparse_config, layer_args_list
+    ):
         """Checks if layers in the model are correctly mapped to its arguments.

         Args:
@@ -82,14 +89,14 @@ def _check_register_layer(self, activation_sparsifier, defaults, sparse_config, layer_args_list):
             sparse_config_actual = copy.deepcopy(sparse_config)
             sparse_config_actual.update(sparse_config_layer)

-            name = module_to_fqn(activation_sparsifier.model, layer_arg['layer'])
+            name = module_to_fqn(activation_sparsifier.model, layer_arg["layer"])

-            assert data_groups[name]['sparse_config'] == sparse_config_actual
+            assert data_groups[name]["sparse_config"] == sparse_config_actual

             # assert the rest
             other_config_actual = copy.deepcopy(defaults)
             other_config_actual.update(layer_arg)
-            other_config_actual.pop('layer')
+            other_config_actual.pop("layer")

             for key, value in other_config_actual.items():
                 assert key in data_groups[name]
@@ -119,13 +126,15 @@ def _check_pre_forward_hook(self, activation_sparsifier, data_list):
         data_agg_actual = data_list[0]
         model = activation_sparsifier.model
         layer_name = module_to_fqn(model, model.conv1)
-        agg_fn = activation_sparsifier.data_groups[layer_name]['aggregate_fn']
+        agg_fn = activation_sparsifier.data_groups[layer_name]["aggregate_fn"]

         for i in range(1, len(data_list)):
             data_agg_actual = agg_fn(data_agg_actual, data_list[i])

-        assert 'data' in activation_sparsifier.data_groups[layer_name]
-        assert torch.all(activation_sparsifier.data_groups[layer_name]['data'] == data_agg_actual)
+        assert "data" in activation_sparsifier.data_groups[layer_name]
+        assert torch.all(
+            activation_sparsifier.data_groups[layer_name]["data"] == data_agg_actual
+        )

         return data_agg_actual

@@ -144,20 +153,19 @@ def _check_step(self, activation_sparsifier, data_agg_actual):
         layer_name = module_to_fqn(model, model.conv1)
         assert layer_name is not None

-        reduce_fn = activation_sparsifier.data_groups[layer_name]['reduce_fn']
+        reduce_fn = activation_sparsifier.data_groups[layer_name]["reduce_fn"]

         data_reduce_actual = reduce_fn(data_agg_actual)
-        mask_fn = activation_sparsifier.data_groups[layer_name]['mask_fn']
-        sparse_config = activation_sparsifier.data_groups[layer_name]['sparse_config']
+        mask_fn = activation_sparsifier.data_groups[layer_name]["mask_fn"]
+        sparse_config = activation_sparsifier.data_groups[layer_name]["sparse_config"]
         mask_actual = mask_fn(data_reduce_actual, **sparse_config)

         mask_model = activation_sparsifier.get_mask(layer_name)

         assert torch.all(mask_model == mask_actual)

         for config in activation_sparsifier.data_groups.values():
-            assert 'data' not in config
-
+            assert "data" not in config

     def _check_squash_mask(self, activation_sparsifier, data):
         """Makes sure that squash_mask() works as usual. Specifically, checks
@@ -172,32 +180,41 @@ def _check_squash_mask(self, activation_sparsifier, data):
             data (torch tensor)
                 dummy batched data
         """
+
         # create a forward hook for checking output == layer(input * mask)
         def check_output(name):
             mask = activation_sparsifier.get_mask(name)
-            features = activation_sparsifier.data_groups[name].get('features')
-            feature_dim = activation_sparsifier.data_groups[name].get('feature_dim')
+            features = activation_sparsifier.data_groups[name].get("features")
+            feature_dim = activation_sparsifier.data_groups[name].get("feature_dim")

             def hook(module, input, output):
                 input_data = input[0]
                 if features is None:
                     assert torch.all(mask * input_data == output)
                 else:
                     for feature_idx in range(0, len(features)):
-                        feature = torch.Tensor([features[feature_idx]], device=input_data.device).long()
-                        inp_data_feature = torch.index_select(input_data, feature_dim, feature)
-                        out_data_feature = torch.index_select(output, feature_dim, feature)
+                        feature = torch.Tensor(
+                            [features[feature_idx]], device=input_data.device
+                        ).long()
+                        inp_data_feature = torch.index_select(
+                            input_data, feature_dim, feature
+                        )
+                        out_data_feature = torch.index_select(
+                            output, feature_dim, feature
+                        )
+
+                        assert torch.all(
+                            mask[feature_idx] * inp_data_feature == out_data_feature
+                        )

-                        assert torch.all(mask[feature_idx] * inp_data_feature == out_data_feature)
             return hook

         for name, config in activation_sparsifier.data_groups.items():
-            if 'identity' in name:
-                config['layer'].register_forward_hook(check_output(name))
+            if "identity" in name:
+                config["layer"].register_forward_hook(check_output(name))

         activation_sparsifier.model(data)

-
     def _check_state_dict(self, sparsifier1):
         """Checks if loading and restoring of state_dict() works as expected.
         Basically, dumps the state of the sparsifier and loads it in the other sparsifier
@@ -222,8 +239,8 @@ def _check_state_dict(self, sparsifier1):

         for name, state in sparsifier2.state.items():
             assert name in sparsifier1.state
-            mask1 = sparsifier1.state[name]['mask']
-            mask2 = state['mask']
+            mask1 = sparsifier1.state[name]["mask"]
+            mask2 = state["mask"]

             if mask1 is None:
                 assert mask2 is None
@@ -237,8 +254,8 @@ def _check_state_dict(self, sparsifier1):
                     assert torch.all(mask1 == mask2)

         # make sure that the state dict is stored as torch sparse
-        for state in state_dict['state'].values():
-            mask = state['mask']
+        for state in state_dict["state"].values():
+            mask = state["mask"]
             if mask is not None:
                 if isinstance(mask, List):
                     for idx in range(len(mask)):
@@ -252,8 +269,16 @@ def _check_state_dict(self, sparsifier1):
             assert layer_name in dg2

             # exclude hook and layer
-            config1 = {key: value for key, value in config.items() if key not in ['hook', 'layer']}
-            config2 = {key: value for key, value in dg2[layer_name].items() if key not in ['hook', 'layer']}
+            config1 = {
+                key: value
+                for key, value in config.items()
+                if key not in ["hook", "layer"]
+            }
+            config2 = {
+                key: value
+                for key, value in dg2[layer_name].items()
+                if key not in ["hook", "layer"]
+            }

             assert config1 == config2

@@ -263,6 +288,7 @@ def test_activation_sparsifier(self):
         till squash_mask().
         The idea is to check that everything works as expected while in the workflow.
         """
+
         # defining aggregate, reduce and mask functions
         def agg_fn(x, y):
             return x + y
@@ -287,14 +313,9 @@ def _vanilla_norm_sparsifier(data, sparsity_level):

         # Creating default function and sparse configs
         # default sparse_config
-        sparse_config = {
-            'sparsity_level': 0.5
-        }
+        sparse_config = {"sparsity_level": 0.5}

-        defaults = {
-            'aggregate_fn': agg_fn,
-            'reduce_fn': reduce_fn
-        }
+        defaults = {"aggregate_fn": agg_fn, "reduce_fn": reduce_fn}

         # simulate the workflow
         # STEP 1: make data and activation sparsifier object
@@ -306,43 +327,51 @@ def _vanilla_norm_sparsifier(data, sparsity_level):

         # STEP 2: Register some layers
         register_layer1_args = {
-            'layer': model.conv1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.conv1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer1 = {'sparsity_level': 0.3}
+        sparse_config_layer1 = {"sparsity_level": 0.3}

         register_layer2_args = {
-            'layer': model.linear1,
-            'features': [0, 10, 234],
-            'feature_dim': 1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.linear1,
+            "features": [0, 10, 234],
+            "feature_dim": 1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer2 = {'sparsity_level': 0.1}
+        sparse_config_layer2 = {"sparsity_level": 0.1}

         register_layer3_args = {
-            'layer': model.identity1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.identity1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer3 = {'sparsity_level': 0.3}
+        sparse_config_layer3 = {"sparsity_level": 0.3}

         register_layer4_args = {
-            'layer': model.identity2,
-            'features': [0, 10, 20],
-            'feature_dim': 1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.identity2,
+            "features": [0, 10, 20],
+            "feature_dim": 1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer4 = {'sparsity_level': 0.1}
+        sparse_config_layer4 = {"sparsity_level": 0.1}

-        layer_args_list = [(register_layer1_args, sparse_config_layer1), (register_layer2_args, sparse_config_layer2)]
-        layer_args_list += [(register_layer3_args, sparse_config_layer3), (register_layer4_args, sparse_config_layer4)]
+        layer_args_list = [
+            (register_layer1_args, sparse_config_layer1),
+            (register_layer2_args, sparse_config_layer2),
+        ]
+        layer_args_list += [
+            (register_layer3_args, sparse_config_layer3),
+            (register_layer4_args, sparse_config_layer4),
+        ]

         # Registering..
         for layer_args in layer_args_list:
             layer_arg, sparse_config_layer = layer_args
             activation_sparsifier.register_layer(**layer_arg, **sparse_config_layer)

         # check if things are registered correctly
-        self._check_register_layer(activation_sparsifier, defaults, sparse_config, layer_args_list)
+        self._check_register_layer(
+            activation_sparsifier, defaults, sparse_config, layer_args_list
+        )

         # check state_dict after registering and before model forward
         self._check_state_dict(activation_sparsifier)