Skip to content

Commit 63056fc

Browse files
Add Classification And Detection Scene (#259)
* Add classification scene
  Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
* Add detection scene
  Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
* Add tests
  Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
* Add title to overlay
  Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
* Pass name and confidence separately
  Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
* Fix tests
  Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
---------
Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
1 parent 4d4bf20 commit 63056fc

File tree

6 files changed

+126
-9
lines changed

6 files changed

+126
-9
lines changed

src/python/model_api/models/result/detection.py

+5
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def label_names(self, value):
111111

112112
@property
def saliency_map(self):
    """Return the saliency map stored on this result (used for XAI).

    Returns:
        np.ndarray: Saliency map in dim of (B, N_CLASSES, H, W).
    """
    return self._saliency_map
115120

116121
@saliency_map.setter

src/python/model_api/visualizer/layout/hstack.py

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import PIL
1111

12+
from model_api.visualizer.primitive import Overlay
13+
1214
from .layout import Layout
1315

1416
if TYPE_CHECKING:
@@ -31,6 +33,8 @@ def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, sc
3133
images = []
3234
for _primitive in scene.get_primitives(primitive):
3335
image_ = _primitive.compute(image.copy())
36+
if isinstance(_primitive, Overlay):
37+
image_ = Overlay.overlay_labels(image=image_, labels=_primitive.label)
3438
images.append(image_)
3539
return self._stitch(*images)
3640
return None

src/python/model_api/visualizer/primitive/overlay.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
from __future__ import annotations
77

8+
from typing import Union
9+
810
import numpy as np
911
import PIL
12+
from PIL import ImageFont
1013

1114
from .primitive import Primitive
1215

@@ -18,11 +21,18 @@ class Overlay(Primitive):
1821
1922
Args:
2023
image (PIL.Image | np.ndarray): Image to be overlaid.
24+
label (str | None): Optional label name to overlay.
2125
opacity (float): Opacity of the overlay.
2226
"""
2327

24-
def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
28+
def __init__(
29+
self,
30+
image: PIL.Image | np.ndarray,
31+
opacity: float = 0.4,
32+
label: Union[str, None] = None,
33+
) -> None:
2534
self.image = self._to_pil(image)
35+
self.label = label
2636
self.opacity = opacity
2737

2838
def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
@@ -33,3 +43,22 @@ def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
3343
def compute(self, image: PIL.Image) -> PIL.Image:
    """Blend the stored overlay onto ``image``.

    The overlay is first resized to ``image.size`` and then blended with ``image`` using
    ``self.opacity`` as the blend factor.
    """
    target_size = image.size
    return PIL.Image.blend(image, self.image.resize(target_size), self.opacity)
46+
47+
@classmethod
def overlay_labels(cls, image: PIL.Image, labels: Union[list[str], str, None] = None) -> PIL.Image:
    """Draw labels at the bottom center of the image.

    The labels are joined with ", " and rendered as black text on a white box, which is then
    pasted onto ``image``. This is handy when you want to add a label to the image.

    Args:
        image (PIL.Image): Image to draw on. It is modified in place by the paste.
        labels (list[str] | str | None): A single label or a list of labels. Nothing is drawn
            when this is ``None`` or empty.

    Returns:
        PIL.Image: The same ``image`` object that was passed in.
    """
    if not labels:
        # None, "" and [] all mean there is no text to render.
        return image

    # Import the submodules explicitly: ``import PIL`` alone does not load ``PIL.ImageDraw``,
    # so attribute access on the package only works if another module imported it first.
    from PIL import Image, ImageDraw

    text = labels if isinstance(labels, str) else ", ".join(labels)
    font = ImageFont.load_default(size=18)
    buffer_y = 5

    # Measure the text on a throwaway canvas to size the label box.
    left, top, right, bottom = ImageDraw.Draw(Image.new("RGB", (1, 1))).textbbox((0, 0), text, font=font)
    label_box = Image.new("RGB", (right - left, bottom + buffer_y - top), "white")
    ImageDraw.Draw(label_box).text((0, 0), text, font=font, fill="black")

    # Horizontally centered, sitting ``buffer_y`` pixels above the bottom edge.
    paste_x = image.width // 2 - label_box.width // 2
    paste_y = image.height - label_box.height - buffer_y
    image.paste(label_box, (paste_x, paste_y))
    return image

src/python/model_api/visualizer/scene/classification.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55

66
from typing import Union
77

8+
import cv2
89
from PIL import Image
910

1011
from model_api.models.result import ClassificationResult
1112
from model_api.visualizer.layout import Flatten, Layout
12-
from model_api.visualizer.primitive import Overlay
13+
from model_api.visualizer.primitive import Label, Overlay
1314

1415
from .scene import Scene
1516

@@ -18,9 +19,28 @@ class ClassificationScene(Scene):
1819
"""Classification Scene."""
1920

2021
def __init__(self, image: Image, result: ClassificationResult, layout: Union[Layout, None] = None) -> None:
    """Build the scene from a classification result.

    Args:
        image (Image): Base image of the scene.
        result (ClassificationResult): Result that is converted into label and overlay primitives.
        layout (Layout | None): Optional layout, forwarded unchanged to the parent initializer.
    """
    label_primitives = self._get_labels(result)
    overlay_primitives = self._get_overlays(result)
    super().__init__(base=image, label=label_primitives, overlay=overlay_primitives, layout=layout)
28+
29+
def _get_labels(self, result: ClassificationResult) -> list[Label]:
30+
labels = []
31+
if result.top_labels is not None and len(result.top_labels) > 0:
32+
for label in result.top_labels:
33+
if label.name is not None:
34+
labels.append(Label(label=label.name, score=label.confidence))
35+
return labels
36+
37+
def _get_overlays(self, result: ClassificationResult) -> list[Overlay]:
38+
overlays = []
39+
if result.saliency_map is not None and result.saliency_map.size > 0:
40+
saliency_map = cv2.cvtColor(result.saliency_map, cv2.COLOR_BGR2RGB)
41+
overlays.append(Overlay(saliency_map))
42+
return overlays
2343

2444
@property
def default_layout(self) -> Layout:
    """Default layout of the scene: a ``Flatten`` over the ``Overlay`` and ``Label`` primitives."""
    flattened = Flatten(Overlay, Label)
    return flattened

src/python/model_api/visualizer/scene/detection.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
from typing import Union
77

8+
import cv2
89
from PIL import Image
910

1011
from model_api.models.result import DetectionResult
11-
from model_api.visualizer.layout import Layout
12+
from model_api.visualizer.layout import Flatten, HStack, Layout
13+
from model_api.visualizer.primitive import BoundingBox, Label, Overlay
1214

1315
from .scene import Scene
1416

@@ -17,5 +19,31 @@ class DetectionScene(Scene):
1719
"""Detection Scene."""
1820

1921
def __init__(self, image: Image, result: DetectionResult, layout: Union[Layout, None] = None) -> None:
    """Build the scene from a detection result.

    Args:
        image (Image): Base image of the scene.
        result (DetectionResult): Result that is converted into bounding-box and overlay primitives.
        layout (Layout | None): Optional layout, forwarded unchanged to the parent initializer.
    """
    box_primitives = self._get_bounding_boxes(result)
    overlay_primitives = self._get_overlays(result)
    super().__init__(base=image, bounding_box=box_primitives, overlay=overlay_primitives, layout=layout)
28+
29+
def _get_overlays(self, result: DetectionResult) -> list[Overlay]:
30+
overlays = []
31+
# Add only the overlays that are predicted
32+
label_index_mapping = dict(zip(result.labels, result.label_names))
33+
for label_index, label_name in label_index_mapping.items():
34+
# Index 0 as it assumes only one batch
35+
saliency_map = cv2.applyColorMap(result.saliency_map[0][label_index], cv2.COLORMAP_JET)
36+
overlays.append(Overlay(saliency_map, label=label_name.title()))
37+
return overlays
38+
39+
def _get_bounding_boxes(self, result: DetectionResult) -> list[BoundingBox]:
40+
bounding_boxes = []
41+
for score, label_name, bbox in zip(result.scores, result.label_names, result.bboxes):
42+
x1, y1, x2, y2 = bbox
43+
label = f"{label_name} ({score:.2f})"
44+
bounding_boxes.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2, label=label))
45+
return bounding_boxes
46+
47+
@property
def default_layout(self) -> Layout:
    """Default layout of the scene.

    An ``HStack`` of a ``Flatten`` over the ``BoundingBox`` and ``Label`` primitives, followed by
    the ``Overlay`` primitives.
    """
    annotated_image = Flatten(BoundingBox, Label)
    return HStack(annotated_image, Overlay)

tests/python/unit/visualizer/test_scene.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import numpy as np
99
from PIL import Image
1010

11-
from model_api.models.result import AnomalyResult
11+
from model_api.models.result import AnomalyResult, ClassificationResult, DetectionResult
12+
from model_api.models.result.classification import Label
1213
from model_api.visualizer import Visualizer
1314

1415

@@ -32,3 +33,33 @@ def test_anomaly_scene(mock_image: Image, tmpdir: Path):
3233
visualizer = Visualizer()
3334
visualizer.save(mock_image, anomaly_result, tmpdir / "anomaly_scene.jpg")
3435
assert Path(tmpdir / "anomaly_scene.jpg").exists()
36+
37+
38+
def test_classification_scene(mock_image: Image, tmpdir: Path):
    """Check that saving a classification scene writes the output file."""
    output_path = tmpdir / "classification_scene.jpg"
    top_labels = [
        Label(name="cat", confidence=0.95),
        Label(name="dog", confidence=0.90),
    ]
    result = ClassificationResult(
        top_labels=top_labels,
        saliency_map=np.ones(mock_image.size, dtype=np.uint8),
    )

    Visualizer().save(mock_image, result, output_path)

    assert Path(output_path).exists()
52+
53+
54+
def test_detection_scene(mock_image: Image, tmpdir: Path):
    """Check that saving a detection scene writes the output file."""
    output_path = tmpdir / "detection_scene.jpg"
    boxes = np.array([[0, 0, 128, 128], [32, 32, 96, 96]])
    # One batch, two classes, 6x8 maps, fully saturated.
    saliency = (np.ones((1, 2, 6, 8)) * 255).astype(np.uint8)
    result = DetectionResult(
        bboxes=boxes,
        labels=np.array([0, 1]),
        label_names=["person", "car"],
        scores=np.array([0.85, 0.75]),
        saliency_map=saliency,
    )

    Visualizer().save(mock_image, result, output_path)

    assert Path(output_path).exists()

0 commit comments

Comments
 (0)