@@ -53,9 +53,7 @@ def __init__(
         ),  # For fixed input size models like ViT
         **kwargs,
     ):
-        super().__init__(
-            model=model, preprocess_fn=preprocess_fn, device_name=device_name
-        )
+        super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
         self._target_layer = target_layer
         self._embed_scaling = embed_scaling
         self._input_size = input_size
@@ -77,9 +75,7 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module:

         # Feature
         if self._target_layer:
-            feature_module = self._find_feature_module_by_name(
-                model, self._target_layer
-            )
+            feature_module = self._find_feature_module_by_name(model, self._target_layer)
         else:
             feature_module = self._find_feature_module_auto(model)
         feature_module.register_forward_hook(self._feature_hook)
@@ -112,9 +108,7 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping:
             output[name] = data.numpy(force=True)
         return output

-    def _find_feature_module_by_name(
-        self, model: torch.nn.Module, target_name: str
-    ) -> torch.nn.Module:
+    def _find_feature_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module:
         """Search the last layer by name substring match."""
         target_module = None
         for name, module in model.named_modules():
@@ -135,9 +129,7 @@ def _has_spatial_dim(shape: torch.Size):
                 return False
             if shape[2] <= 1 or shape[3] <= 1:  # H > 1 and W > 1
                 return False
-            if (
-                shape[1] <= shape[2] or shape[1] <= shape[3]
-            ):  # H < C and W < C for feature maps generally
+            if shape[1] <= shape[2] or shape[1] <= shape[3]:  # H < C and W < C for feature maps generally
                 return False
             return True

@@ -149,36 +141,26 @@ def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
             if _has_spatial_dim(shape):
                 self._feature_module = module

-        global_hook_handle = torch.nn.modules.module.register_module_forward_hook(
-            _detect_hook
-        )
+        global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
         try:
             module.forward(torch.zeros((1, 3, *self._input_size)))
         finally:
             global_hook_handle.remove()
         if self._feature_module is None:
-            raise RuntimeError(
-                "Feature module with 4D output is not found in the torch model"
-            )
-        if (
-            self._feature_module.index / self._num_modules < 0.5
-        ):  # Check if ViT-like architectures
+            raise RuntimeError("Feature module with 4D output is not found in the torch model")
+        if self._feature_module.index / self._num_modules < 0.5:  # Check if ViT-like architectures
             raise RuntimeError(
                 f"Modules with 4D output end in early-half stages: {100 * self._feature_module.index / self._num_modules}%"
             )

         return self._feature_module

-    def _feature_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> torch.Tensor:
+    def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         """Manipulate feature map for saliency map generation."""
         self._feature_map = output
         return output

-    def _output_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
+    def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
         """Split combined output B0xC into BxC prediction and BxCxHxW saliency map."""
         return {
             "prediction": output,
@@ -195,18 +177,14 @@ def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor:
         """Normalize saliency maps."""
         max_values = saliency_map.max(dim=-1, keepdim=True).values
         min_values = saliency_map.min(dim=-1, keepdim=True).values
-        saliency_map = (
-            255 * (saliency_map - min_values) / (max_values - min_values + 1e-12)
-        )
+        saliency_map = 255 * (saliency_map - min_values) / (max_values - min_values + 1e-12)
         return saliency_map.to(torch.uint8)


 class TorchActivationMap(TorchWhiteBoxMethod):
     """ActivationMap. Mean of the feature map along the channel dimension."""

-    def _output_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
+    def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
         feature_map = self._feature_map
         batch_size, _, h, w = feature_map.shape
         activation_map = torch.mean(feature_map, dim=1)
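For reference, _normalize_map above rescales each row of the flattened saliency tensor to [0, 255] independently, and the 1e-12 term only guards against division by zero for constant rows. A quick standalone check with made-up values (not taken from this diff):

import torch

saliency = torch.tensor([[0.2, 0.5, 0.9], [3.0, 3.0, 3.0]])
max_values = saliency.max(dim=-1, keepdim=True).values
min_values = saliency.min(dim=-1, keepdim=True).values
scaled = 255 * (saliency - min_values) / (max_values - min_values + 1e-12)
print(scaled.to(torch.uint8))  # first row spans 0..255; the constant row collapses to 0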
@@ -233,56 +211,42 @@ def __init__(self, *args, optimize_gap: bool = False, **kwargs):
         self._optimize_gap = optimize_gap
         super().__init__(*args, **kwargs)

-    def _feature_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> torch.Tensor:
+    def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
         batch_size, c, h, w = self._feature_shape = output.shape
         feature_map = output
         if self._optimize_gap:
-            feature_map = feature_map.reshape([batch_size, c, h * w]).mean(dim=-1)[
-                :, :, None, None
-            ]  # Spatial average
+            feature_map = feature_map.reshape([batch_size, c, h * w]).mean(dim=-1)[:, :, None, None]  # Spatial average
         feature_maps = [feature_map]
         for i in range(batch_size):
             mosaic_feature_map = self._get_mosaic_feature_map(output[i], c, h, w)
             feature_maps.append(mosaic_feature_map)
         return torch.cat(feature_maps)

-    def _output_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
+    def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
         """Split combined output B0xC into BxC prediction and BxCxHxW saliency map."""
         batch_size, _, h, w = self._feature_shape  # B0xDxHxW
         num_classes = output.shape[1]  # C
         predictions = output[:batch_size]  # BxC
         saliency_maps = output[batch_size:]  # BHWxC
-        saliency_maps = saliency_maps.reshape(
-            [batch_size, h * w, num_classes]
-        )  # BxHWxC
+        saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes])  # BxHWxC
         saliency_maps = saliency_maps.transpose(1, 2)  # BxCxHW
         if self._embed_scaling:
             saliency_maps = saliency_maps.reshape((batch_size * num_classes, h * w))
             saliency_maps = self._normalize_map(saliency_maps)
-        saliency_maps = saliency_maps.reshape(
-            [batch_size, num_classes, h, w]
-        )  # BxCxHxW
+        saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w])  # BxCxHxW
         return {
             "prediction": predictions,
             SALIENCY_MAP_OUTPUT_NAME: saliency_maps,
         }

-    def _get_mosaic_feature_map(
-        self, feature_map: torch.Tensor, c: int, h: int, w: int
-    ) -> torch.Tensor:
+    def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: int) -> torch.Tensor:
         if self._optimize_gap:
             # if isinstance(model_neck, GlobalAveragePooling):
             # Optimization workaround for the GAP case (simulate GAP with a simpler compute graph)
             # Possible due to static sparsity of mosaic_feature_map
             # Makes the downstream GAP operation a no-op
-            feature_map_transposed = torch.flatten(feature_map, start_dim=1).transpose(
-                0, 1
-            )[:, :, None, None]
+            feature_map_transposed = torch.flatten(feature_map, start_dim=1).transpose(0, 1)[:, :, None, None]
             mosaic_feature_map = feature_map_transposed / (h * w)
         else:
             feature_map_repeated = feature_map.repeat(h * w, 1, 1, 1)
@@ -291,9 +255,7 @@ def _get_mosaic_feature_map(
             for i in range(h):
                 for j in range(w):
                     k = spatial_order[i, j]
-                    mosaic_feature_map_mask[k, :, i, j] = torch.ones(c).to(
-                        feature_map.device
-                    )
+                    mosaic_feature_map_mask[k, :, i, j] = torch.ones(c).to(feature_map.device)
             mosaic_feature_map = feature_map_repeated * mosaic_feature_map_mask
         return mosaic_feature_map

@@ -328,24 +290,16 @@ def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
         self._feature_module = None
         norm_modules = []
         for name, sub_module in module.named_modules():
-            if (
-                "LayerNorm" in type(sub_module).__name__
-                or "BatchNorm" in type(sub_module).__name__
-                or "norm1" in name
-            ):
+            if "LayerNorm" in type(sub_module).__name__ or "BatchNorm" in type(sub_module).__name__ or "norm1" in name:
                 norm_modules.append(sub_module)

         if len(norm_modules) < 3:
-            raise RuntimeError(
-                "Feature modules with LayerNorm or BatchNorm are less than 3 in the torch model"
-            )
+            raise RuntimeError("Feature modules with LayerNorm or BatchNorm are less than 3 in the torch model")

         self._feature_module = norm_modules[-3]
         return self._feature_module

-    def _feature_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> torch.Tensor:
+    def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
         feature_map = output
         batch_size, num_tokens, dim = feature_map.shape
@@ -357,19 +311,15 @@ def _feature_hook(
             feature_maps.append(mosaic_feature_map)
         return torch.cat(feature_maps)

-    def _get_mosaic_feature_map(
-        self, feature_map: torch.Tensor, c: int, h: int, w: int
-    ) -> torch.Tensor:
+    def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: int) -> torch.Tensor:
         num_tokens = h * w + 1
         mosaic_feature_map = torch.zeros(h * w, num_tokens, c).to(feature_map.device)

         if self._use_gaussian:
             if self._use_cls_token:
                 mosaic_feature_map[:, 0, :] = feature_map[0, :]
             feature_map_spatial = feature_map[1:, :].reshape(1, h, w, c)
-            feature_map_spatial_repeated = feature_map_spatial.repeat(
-                h * w, 1, 1, 1
-            )  # 196, 14, 14, 192
+            feature_map_spatial_repeated = feature_map_spatial.repeat(h * w, 1, 1, 1)  # 196, 14, 14, 192

             spatial_order = torch.arange(h * w).reshape(h, w)
             gaussian = torch.tensor(
@@ -379,40 +329,26 @@ def _get_mosaic_feature_map(
                     [1 / 16.0, 1 / 8.0, 1 / 16.0],
                 ],
             ).to(feature_map.device)
-            mosaic_feature_map_mask_padded = torch.zeros(h * w, h + 2, w + 2).to(
-                feature_map.device
-            )
+            mosaic_feature_map_mask_padded = torch.zeros(h * w, h + 2, w + 2).to(feature_map.device)
             for i in range(h):
                 for j in range(w):
                     k = spatial_order[i, j]
                     i_pad = i + 1
                     j_pad = j + 1
-                    mosaic_feature_map_mask_padded[
-                        k, i_pad - 1 : i_pad + 2, j_pad - 1 : j_pad + 2
-                    ] = gaussian
+                    mosaic_feature_map_mask_padded[k, i_pad - 1 : i_pad + 2, j_pad - 1 : j_pad + 2] = gaussian
             mosaic_feature_map_mask = mosaic_feature_map_mask_padded[:, 1:-1, 1:-1]
-            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(3).repeat(
-                1, 1, 1, c
-            )
+            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(3).repeat(1, 1, 1, c)

-            mosaic_fm_wo_cls_token = (
-                feature_map_spatial_repeated * mosaic_feature_map_mask
-            )
-            mosaic_feature_map[:, 1:, :] = mosaic_fm_wo_cls_token.reshape(
-                h * w, h * w, c
-            )
+            mosaic_fm_wo_cls_token = feature_map_spatial_repeated * mosaic_feature_map_mask
+            mosaic_feature_map[:, 1:, :] = mosaic_fm_wo_cls_token.reshape(h * w, h * w, c)
         else:
             feature_map_repeated = feature_map.unsqueeze(0).repeat(h * w, 1, 1)
-            mosaic_feature_map_mask = torch.zeros(h * w, num_tokens).to(
-                feature_map.device
-            )
+            mosaic_feature_map_mask = torch.zeros(h * w, num_tokens).to(feature_map.device)
             for i in range(h * w):
                 mosaic_feature_map_mask[i, i + 1] = torch.ones(1).to(feature_map.device)
             if self._use_cls_token:
                 mosaic_feature_map_mask[:, 0] = torch.ones(1).to(feature_map.device)
-            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(2).repeat(
-                1, 1, c
-            )
+            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(2).repeat(1, 1, c)
             mosaic_feature_map = feature_map_repeated * mosaic_feature_map_mask

         return mosaic_feature_map
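As background for the mosaic code above: ReciproCAM builds one masked copy of the feature map per spatial cell, keeps only that cell's channel vector, and uses the class scores of the masked copies as the per-class saliency values. A compact sketch of that idea with a made-up GAP-plus-linear head (the head and tensor sizes are illustrative assumptions, not this repository's API):

import torch

c, h, w = 8, 4, 4                       # hypothetical feature map size
num_classes = 5
feature_map = torch.randn(c, h, w)
head = torch.nn.Linear(c, num_classes)  # stand-in classifier head after GAP

# One masked copy per spatial location: only cell (i, j) keeps its channel vector.
mosaic = torch.zeros(h * w, c, h, w)
for i in range(h):
    for j in range(w):
        mosaic[i * w + j, :, i, j] = feature_map[:, i, j]

pooled = mosaic.mean(dim=(2, 3))        # (h*w, c) global average pooling
scores = head(pooled)                   # (h*w, num_classes)
saliency = scores.transpose(0, 1).reshape(num_classes, h, w)  # CxHxW saliency map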