
Commit 37c7715

Commit message: formatting
1 parent f4f67fb commit 37c7715

File tree

2 files changed: +34 -104 lines changed


openvino_xai/methods/white_box/torch.py (+31 -95)
@@ -53,9 +53,7 @@ def __init__(
         ),  # For fixed input size models like ViT
         **kwargs,
     ):
-        super().__init__(
-            model=model, preprocess_fn=preprocess_fn, device_name=device_name
-        )
+        super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
         self._target_layer = target_layer
         self._embed_scaling = embed_scaling
         self._input_size = input_size
@@ -77,9 +75,7 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
 
         # Feature
         if self._target_layer:
-            feature_module = self._find_feature_module_by_name(
-                model, self._target_layer
-            )
+            feature_module = self._find_feature_module_by_name(model, self._target_layer)
         else:
             feature_module = self._find_feature_module_auto(model)
         feature_module.register_forward_hook(self._feature_hook)
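
Note: the whole white-box path rests on PyTorch forward hooks (`register_forward_hook` above, and the global `register_module_forward_hook` later in this diff). Here is a minimal, self-contained reminder of that mechanism, independent of the commit:

import torch

captured = {}

def feature_hook(module: torch.nn.Module, inputs, output: torch.Tensor) -> torch.Tensor:
    """Stash the layer output; returning a tensor lets the hook replace it downstream."""
    captured["feature_map"] = output
    return output

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 5),
)
handle = model[0].register_forward_hook(feature_hook)
_ = model(torch.zeros(1, 3, 32, 32))
print(captured["feature_map"].shape)  # torch.Size([1, 8, 32, 32])
handle.remove()  # same cleanup pattern the diff uses for the global detection hook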
@@ -112,9 +108,7 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping:
             output[name] = data.numpy(force=True)
         return output
 
-    def _find_feature_module_by_name(
-        self, model: torch.nn.Module, target_name: str
-    ) -> torch.nn.Module:
+    def _find_feature_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module:
         """Search the last layer by name sub string match."""
         target_module = None
         for name, module in model.named_modules():
@@ -135,9 +129,7 @@ def _has_spatial_dim(shape: torch.Size):
                 return False
             if shape[2] <= 1 or shape[3] <= 1:  # H > 1 and W > 1
                 return False
-            if (
-                shape[1] <= shape[2] or shape[1] <= shape[3]
-            ):  # H < C and H < C for feature maps generally
+            if shape[1] <= shape[2] or shape[1] <= shape[3]:  # H < C and H < C for feature maps generally
                 return False
             return True

@@ -149,36 +141,26 @@ def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
                 if _has_spatial_dim(shape):
                     self._feature_module = module
 
-        global_hook_handle = torch.nn.modules.module.register_module_forward_hook(
-            _detect_hook
-        )
+        global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
         try:
             module.forward(torch.zeros((1, 3, *self._input_size)))
         finally:
             global_hook_handle.remove()
         if self._feature_module is None:
-            raise RuntimeError(
-                "Feature module with 4D output is not found in the torch model"
-            )
-        if (
-            self._feature_module.index / self._num_modules < 0.5
-        ):  # Check if ViT-like architectures
+            raise RuntimeError("Feature module with 4D output is not found in the torch model")
+        if self._feature_module.index / self._num_modules < 0.5:  # Check if ViT-like architectures
             raise RuntimeError(
                 f"Modules with 4D output end in early-half stages: {100 * self._feature_module.index / self._num_modules}%"
             )
 
         return self._feature_module
 
-    def _feature_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> torch.Tensor:
+    def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         """Manipulate feature map for saliency map generation."""
         self._feature_map = output
         return output
 
-    def _output_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
+    def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
         """Split combined output B0xC into BxC precition and BxCxHxW saliency map."""
         return {
             "prediction": output,
@@ -195,18 +177,14 @@ def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor:
         """Normalize saliency maps."""
         max_values = saliency_map.max(dim=-1, keepdim=True).values
         min_values = saliency_map.min(dim=-1, keepdim=True).values
-        saliency_map = (
-            255 * (saliency_map - min_values) / (max_values - min_values + 1e-12)
-        )
+        saliency_map = 255 * (saliency_map - min_values) / (max_values - min_values + 1e-12)
         return saliency_map.to(torch.uint8)
 
 
 class TorchActivationMap(TorchWhiteBoxMethod):
     """ActivationMap. Mean of the feature map along the channel dimension."""
 
-    def _output_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
+    def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
         feature_map = self._feature_map
         batch_size, _, h, w = feature_map.shape
         activation_map = torch.mean(feature_map, dim=1)
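
Side note on the `_normalize_map` one-liner reformatted above: it is a per-row min-max rescale into [0, 255], with a small epsilon guarding against constant maps. A minimal standalone sketch, illustrative only and not part of the commit:

import torch

def normalize_map(saliency_map: torch.Tensor) -> torch.Tensor:
    """Scale each flattened map independently into the uint8 range [0, 255]."""
    max_values = saliency_map.max(dim=-1, keepdim=True).values
    min_values = saliency_map.min(dim=-1, keepdim=True).values
    saliency_map = 255 * (saliency_map - min_values) / (max_values - min_values + 1e-12)
    return saliency_map.to(torch.uint8)

maps = torch.tensor([[0.0, 1.0, 2.0, 3.0], [10.0, 10.0, 10.0, 30.0]])
print(normalize_map(maps))  # each row is scaled independently into 0..255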
@@ -233,56 +211,42 @@ def __init__(self, *args, optimize_gap: bool = False, **kwargs):
         self._optimize_gap = optimize_gap
         super().__init__(*args, **kwargs)
 
-    def _feature_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> torch.Tensor:
+    def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
         batch_size, c, h, w = self._feature_shape = output.shape
         feature_map = output
         if self._optimize_gap:
-            feature_map = feature_map.reshape([batch_size, c, h * w]).mean(dim=-1)[
-                :, :, None, None
-            ]  # Spatial average
+            feature_map = feature_map.reshape([batch_size, c, h * w]).mean(dim=-1)[:, :, None, None]  # Spatial average
         feature_maps = [feature_map]
         for i in range(batch_size):
             mosaic_feature_map = self._get_mosaic_feature_map(output[i], c, h, w)
             feature_maps.append(mosaic_feature_map)
         return torch.cat(feature_maps)
 
-    def _output_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
+    def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
         """Split combined output B0xC into BxC precition and BxCxHxW saliency map."""
         batch_size, _, h, w = self._feature_shape  # B0xDxHxW
         num_classes = output.shape[1]  # C
         predictions = output[:batch_size]  # BxC
         saliency_maps = output[batch_size:]  # BHWxC
-        saliency_maps = saliency_maps.reshape(
-            [batch_size, h * w, num_classes]
-        )  # BxHWxC
+        saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes])  # BxHWxC
         saliency_maps = saliency_maps.transpose(1, 2)  # BxCxHW
         if self._embed_scaling:
             saliency_maps = saliency_maps.reshape((batch_size * num_classes, h * w))
             saliency_maps = self._normalize_map(saliency_maps)
-            saliency_maps = saliency_maps.reshape(
-                [batch_size, num_classes, h, w]
-            )  # BxCxHxW
+            saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w])  # BxCxHxW
         return {
             "prediction": predictions,
             SALIENCY_MAP_OUTPUT_NAME: saliency_maps,
         }
 
-    def _get_mosaic_feature_map(
-        self, feature_map: torch.Tensor, c: int, h: int, w: int
-    ) -> torch.Tensor:
+    def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: int) -> torch.Tensor:
         if self._optimize_gap:
             # if isinstance(model_neck, GlobalAveragePooling):
             # Optimization workaround for the GAP case (simulate GAP with more simple compute graph)
             # Possible due to static sparsity of mosaic_feature_map
             # Makes the downstream GAP operation to be dummy
-            feature_map_transposed = torch.flatten(feature_map, start_dim=1).transpose(
-                0, 1
-            )[:, :, None, None]
+            feature_map_transposed = torch.flatten(feature_map, start_dim=1).transpose(0, 1)[:, :, None, None]
             mosaic_feature_map = feature_map_transposed / (h * w)
         else:
             feature_map_repeated = feature_map.repeat(h * w, 1, 1, 1)
@@ -291,9 +255,7 @@ def _get_mosaic_feature_map(
             for i in range(h):
                 for j in range(w):
                     k = spatial_order[i, j]
-                    mosaic_feature_map_mask[k, :, i, j] = torch.ones(c).to(
-                        feature_map.device
-                    )
+                    mosaic_feature_map_mask[k, :, i, j] = torch.ones(c).to(feature_map.device)
             mosaic_feature_map = feature_map_repeated * mosaic_feature_map_mask
         return mosaic_feature_map
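
To make the masking loop above easier to follow: in the non-GAP path, ReciproCAM builds one masked copy of the feature map per spatial cell, keeping only that cell's activations. The sketch below reconstructs the surrounding context; the `mosaic_feature_map_mask` and `spatial_order` initializations are not visible in this hunk, so their exact form here is an assumption:

import torch

def cnn_mosaic(feature_map: torch.Tensor, c: int, h: int, w: int) -> torch.Tensor:
    """One masked copy of a (c, h, w) map per spatial cell -> (h*w, c, h, w)."""
    feature_map_repeated = feature_map.repeat(h * w, 1, 1, 1)
    # Assumed initializations (they sit outside the visible hunk):
    mosaic_feature_map_mask = torch.zeros(h * w, c, h, w, device=feature_map.device)
    spatial_order = torch.arange(h * w).reshape(h, w)
    for i in range(h):
        for j in range(w):
            k = spatial_order[i, j]
            mosaic_feature_map_mask[k, :, i, j] = torch.ones(c).to(feature_map.device)
    return feature_map_repeated * mosaic_feature_map_mask

mosaics = cnn_mosaic(torch.randn(8, 4, 4), c=8, h=4, w=4)
print(mosaics.shape)  # torch.Size([16, 8, 4, 4]); copy k is non-zero only at cell k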

@@ -328,24 +290,16 @@ def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
         self._feature_module = None
         norm_modules = []
         for name, sub_module in module.named_modules():
-            if (
-                "LayerNorm" in type(sub_module).__name__
-                or "BatchNorm" in type(sub_module).__name__
-                or "norm1" in name
-            ):
+            if "LayerNorm" in type(sub_module).__name__ or "BatchNorm" in type(sub_module).__name__ or "norm1" in name:
                 norm_modules.append(sub_module)
 
         if len(norm_modules) < 3:
-            raise RuntimeError(
-                "Feature modules with LayerNorm or BatchNorm are less than 3 in the torch model"
-            )
+            raise RuntimeError("Feature modules with LayerNorm or BatchNorm are less than 3 in the torch model")
 
         self._feature_module = norm_modules[-3]
         return self._feature_module
 
-    def _feature_hook(
-        self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-    ) -> torch.Tensor:
+    def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
         feature_map = output
         batch_size, num_tokens, dim = feature_map.shape
@@ -357,19 +311,15 @@ def _feature_hook(
             feature_maps.append(mosaic_feature_map)
         return torch.cat(feature_maps)
 
-    def _get_mosaic_feature_map(
-        self, feature_map: torch.Tensor, c: int, h: int, w: int
-    ) -> torch.Tensor:
+    def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: int) -> torch.Tensor:
         num_tokens = h * w + 1
         mosaic_feature_map = torch.zeros(h * w, num_tokens, c).to(feature_map.device)
 
         if self._use_gaussian:
             if self._use_cls_token:
                 mosaic_feature_map[:, 0, :] = feature_map[0, :]
             feature_map_spatial = feature_map[1:, :].reshape(1, h, w, c)
-            feature_map_spatial_repeated = feature_map_spatial.repeat(
-                h * w, 1, 1, 1
-            )  # 196, 14, 14, 192
+            feature_map_spatial_repeated = feature_map_spatial.repeat(h * w, 1, 1, 1)  # 196, 14, 14, 192
 
             spatial_order = torch.arange(h * w).reshape(h, w)
             gaussian = torch.tensor(
@@ -379,40 +329,26 @@ def _get_mosaic_feature_map(
                     [1 / 16.0, 1 / 8.0, 1 / 16.0],
                 ],
             ).to(feature_map.device)
-            mosaic_feature_map_mask_padded = torch.zeros(h * w, h + 2, w + 2).to(
-                feature_map.device
-            )
+            mosaic_feature_map_mask_padded = torch.zeros(h * w, h + 2, w + 2).to(feature_map.device)
             for i in range(h):
                 for j in range(w):
                     k = spatial_order[i, j]
                     i_pad = i + 1
                     j_pad = j + 1
-                    mosaic_feature_map_mask_padded[
-                        k, i_pad - 1 : i_pad + 2, j_pad - 1 : j_pad + 2
-                    ] = gaussian
+                    mosaic_feature_map_mask_padded[k, i_pad - 1 : i_pad + 2, j_pad - 1 : j_pad + 2] = gaussian
             mosaic_feature_map_mask = mosaic_feature_map_mask_padded[:, 1:-1, 1:-1]
-            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(3).repeat(
-                1, 1, 1, c
-            )
+            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(3).repeat(1, 1, 1, c)
 
-            mosaic_fm_wo_cls_token = (
-                feature_map_spatial_repeated * mosaic_feature_map_mask
-            )
-            mosaic_feature_map[:, 1:, :] = mosaic_fm_wo_cls_token.reshape(
-                h * w, h * w, c
-            )
+            mosaic_fm_wo_cls_token = feature_map_spatial_repeated * mosaic_feature_map_mask
+            mosaic_feature_map[:, 1:, :] = mosaic_fm_wo_cls_token.reshape(h * w, h * w, c)
         else:
             feature_map_repeated = feature_map.unsqueeze(0).repeat(h * w, 1, 1)
-            mosaic_feature_map_mask = torch.zeros(h * w, num_tokens).to(
-                feature_map.device
-            )
+            mosaic_feature_map_mask = torch.zeros(h * w, num_tokens).to(feature_map.device)
             for i in range(h * w):
                 mosaic_feature_map_mask[i, i + 1] = torch.ones(1).to(feature_map.device)
             if self._use_cls_token:
                 mosaic_feature_map_mask[:, 0] = torch.ones(1).to(feature_map.device)
-            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(2).repeat(
-                1, 1, c
-            )
+            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(2).repeat(1, 1, c)
             mosaic_feature_map = feature_map_repeated * mosaic_feature_map_mask
 
         return mosaic_feature_map
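
For orientation in the ViT variant: in the simpler non-Gaussian path, each mosaic copy keeps exactly one patch token, plus the CLS token when enabled. A self-contained sketch with toy sizes (the sizes are illustrative, not from the commit):

import torch

h, w, c = 2, 2, 6
num_tokens = h * w + 1  # patch tokens + CLS token
feature_map = torch.randn(num_tokens, c)  # token embeddings of a single image
feature_map_repeated = feature_map.unsqueeze(0).repeat(h * w, 1, 1)  # (h*w, tokens, c)

mosaic_feature_map_mask = torch.zeros(h * w, num_tokens)
for i in range(h * w):
    mosaic_feature_map_mask[i, i + 1] = 1.0  # keep only patch token i in copy i
mosaic_feature_map_mask[:, 0] = 1.0  # keep the CLS token in every copy
mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(2).repeat(1, 1, c)

mosaic_feature_map = feature_map_repeated * mosaic_feature_map_mask
print(mosaic_feature_map.shape)  # torch.Size([4, 5, 6])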

tests/unit/methods/white_box/test_torch.py (+3 -9)
@@ -103,9 +103,7 @@ def test_torch_method():
     assert SALIENCY_MAP_OUTPUT_NAME in output
 
     class DummyMethod(TorchWhiteBoxMethod):
-        def _feature_hook(
-            self, module: torch.nn.Module, inputs: Any, output: torch.Tensor
-        ) -> torch.Tensor:
+        def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
             output = torch.cat((output, output), dim=0)
             return super()._feature_hook(module, inputs, output)

@@ -130,9 +128,7 @@ def _output_hook(
 
 
 def test_prepare_model():
     model = DummyCNN()
-    method = TorchWhiteBoxMethod(
-        model=model, target_layer="feature", prepare_model=False
-    )
+    method = TorchWhiteBoxMethod(model=model, target_layer="feature", prepare_model=False)
     model_xai = method.prepare_model(load_model=False)
     assert method._model_compiled is None
     model_xai = method.prepare_model(load_model=False)
@@ -200,9 +196,7 @@ def test_reciprocam(optimize_gap: bool) -> None:
     batch_size = 2
     num_classes = 3
     model = DummyCNN(num_classes=num_classes)
-    method = TorchReciproCAM(
-        model=model, target_layer="feature", optimize_gap=optimize_gap
-    )
+    method = TorchReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap)
     model_xai = method.prepare_model()
     assert has_xai(model_xai)
     data = np.random.rand(batch_size, 4, 5, 5)
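
For context, the test above exercises roughly the following flow. This is a hypothetical usage sketch: `TinyCNN`, the input shape, and the float32 cast are assumptions standing in for the repo's `DummyCNN` fixture, which is not shown in this diff; the import path follows the file layout above:

import numpy as np
import torch
from openvino_xai.methods.white_box.torch import SALIENCY_MAP_OUTPUT_NAME, TorchReciproCAM

class TinyCNN(torch.nn.Module):
    """Hypothetical stand-in for the test's DummyCNN fixture."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.feature = torch.nn.Conv2d(4, 16, kernel_size=3, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.head = torch.nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.feature(x)  # named "feature" so target_layer="feature" matches by substring
        x = self.pool(x).flatten(1)
        return self.head(x)

method = TorchReciproCAM(model=TinyCNN(), target_layer="feature", optimize_gap=True)
model_xai = method.prepare_model()
data = np.random.rand(2, 4, 5, 5).astype(np.float32)  # defensive float32 cast (assumption)
output = method.model_forward(data)
print(output[SALIENCY_MAP_OUTPUT_NAME].shape)  # expected (2, num_classes, 5, 5)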
