Skip to content

Commit db0bf5a

Browse files
Add Keypoint and Instance Segmentation support in visualizer (#261)
* Add Keypoint and Instance Segmentation support in visualizer - Introduced Keypoint primitive for visualizing keypoints on images. - Updated visualizer to handle Instance Segmentation results and added corresponding scene. - Refactored existing scenes to integrate new functionalities, including KeypointScene and InstanceSegmentationScene. - Enhanced Polygon primitive with opacity and outline width options. - Updated tests to cover new features and ensure functionality. This commit enhances the visualizer's capabilities for handling keypoints and instance segmentation, improving the overall visualization experience. * fix saliency map colour conversion Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add visualization example for VisionAPI Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> --------- Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
1 parent 63056fc commit db0bf5a

File tree

17 files changed

+408
-36
lines changed

17 files changed

+408
-36
lines changed
+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Visualization Example
2+
3+
This example demonstrates how to use the Visualizer in VisionAPI.
4+
5+
## Prerequisites
6+
7+
Install Model API from source. Please refer to the main [README](../../../README.md) for details.
8+
9+
## Run example
10+
11+
To run the example, please execute the following command:
12+
13+
```bash
14+
python run.py --image <path_to_image> --model <path_to_model>.xml --output <path_to_output_image>
15+
```

examples/python/visualization/run.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""Visualization Example."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
import argparse
7+
from argparse import Namespace
8+
9+
import cv2
10+
import numpy as np
11+
from PIL import Image
12+
13+
from model_api.models import Model
14+
from model_api.visualizer import Visualizer
15+
16+
17+
def main(args: Namespace):
18+
image = Image.open(args.image)
19+
20+
model = Model.create_model(args.model)
21+
22+
image_array = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
23+
predictions = model(image_array)
24+
visualizer = Visualizer()
25+
26+
if args.output:
27+
visualizer.save(image=image, result=predictions, path=args.output)
28+
else:
29+
visualizer.show(image=image, result=predictions)
30+
31+
32+
if __name__ == "__main__":
33+
parser = argparse.ArgumentParser()
34+
parser.add_argument("--image", type=str, required=True)
35+
parser.add_argument("--model", type=str, required=True)
36+
parser.add_argument("--output", type=str, required=False)
37+
args = parser.parse_args()
38+
main(args)

src/python/model_api/visualizer/__init__.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,19 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
from .layout import Flatten, HStack, Layout
7-
from .primitive import BoundingBox, Label, Overlay, Polygon
7+
from .primitive import BoundingBox, Keypoint, Label, Overlay, Polygon
88
from .scene import Scene
99
from .visualizer import Visualizer
1010

11-
__all__ = ["BoundingBox", "Label", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
11+
__all__ = [
12+
"BoundingBox",
13+
"Keypoint",
14+
"Label",
15+
"Overlay",
16+
"Polygon",
17+
"Scene",
18+
"Visualizer",
19+
"Layout",
20+
"Flatten",
21+
"HStack",
22+
]

src/python/model_api/visualizer/primitive/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
from .bounding_box import BoundingBox
7+
from .keypoints import Keypoint
78
from .label import Label
89
from .overlay import Overlay
910
from .polygon import Polygon
1011
from .primitive import Primitive
1112

12-
__all__ = ["Primitive", "BoundingBox", "Label", "Overlay", "Polygon"]
13+
__all__ = ["Primitive", "BoundingBox", "Label", "Overlay", "Polygon", "Keypoint"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Keypoints primitive."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from typing import Union
7+
8+
import numpy as np
9+
from PIL import Image, ImageDraw, ImageFont
10+
11+
from .primitive import Primitive
12+
13+
14+
class Keypoint(Primitive):
15+
"""Keypoint primitive.
16+
17+
Args:
18+
keypoints (np.ndarray): Keypoints. Shape: (N, 2)
19+
scores (np.ndarray | None): Scores. Shape: (N,). Defaults to None.
20+
color (str | tuple[int, int, int]): Color of the keypoints. Defaults to "purple".
21+
"""
22+
23+
def __init__(
24+
self,
25+
keypoints: np.ndarray,
26+
scores: Union[np.ndarray, None] = None,
27+
color: Union[str, tuple[int, int, int]] = "purple",
28+
keypoint_size: int = 3,
29+
) -> None:
30+
self.keypoints = self._validate_keypoints(keypoints)
31+
self.scores = scores
32+
self.color = color
33+
self.keypoint_size = keypoint_size
34+
35+
def compute(self, image: Image) -> Image:
36+
"""Draw keypoints on the image."""
37+
draw = ImageDraw.Draw(image)
38+
for keypoint in self.keypoints:
39+
draw.ellipse(
40+
(
41+
keypoint[0] - self.keypoint_size,
42+
keypoint[1] - self.keypoint_size,
43+
keypoint[0] + self.keypoint_size,
44+
keypoint[1] + self.keypoint_size,
45+
),
46+
fill=self.color,
47+
)
48+
49+
if self.scores is not None:
50+
font = ImageFont.load_default(size=18)
51+
for score, keypoint in zip(self.scores, self.keypoints):
52+
textbox = draw.textbbox((0, 0), f"{score:.2f}", font=font)
53+
draw.text(
54+
(keypoint[0] - textbox[2] // 2, keypoint[1] + self.keypoint_size),
55+
f"{score:.2f}",
56+
font=font,
57+
fill=self.color,
58+
)
59+
return image
60+
61+
def _validate_keypoints(self, keypoints: np.ndarray) -> np.ndarray:
62+
if keypoints.shape[1] != 2:
63+
msg = "Keypoints must have shape (N, 2)"
64+
raise ValueError(msg)
65+
return keypoints

src/python/model_api/visualizer/primitive/polygon.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@
55

66
from __future__ import annotations
77

8+
import logging
89
from typing import TYPE_CHECKING
910

1011
import cv2
11-
from PIL import Image, ImageDraw
12+
from PIL import Image, ImageColor, ImageDraw
1213

1314
from .primitive import Primitive
1415

1516
if TYPE_CHECKING:
1617
import numpy as np
1718

19+
logger = logging.getLogger(__name__)
20+
1821

1922
class Polygon(Primitive):
2023
"""Polygon primitive.
@@ -38,9 +41,13 @@ def __init__(
3841
points: list[tuple[int, int]] | None = None,
3942
mask: np.ndarray | None = None,
4043
color: str | tuple[int, int, int] = "blue",
44+
opacity: float = 0.4,
45+
outline_width: int = 2,
4146
) -> None:
4247
self.points = self._get_points(points, mask)
4348
self.color = color
49+
self.opacity = opacity
50+
self.outline_width = outline_width
4451

4552
def _get_points(self, points: list[tuple[int, int]] | None, mask: np.ndarray | None) -> list[tuple[int, int]]:
4653
"""Get points from either points or mask.
@@ -76,6 +83,13 @@ def _get_points_from_mask(self, mask: np.ndarray) -> list[tuple[int, int]]:
7683
List of points.
7784
"""
7885
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
86+
# incase of multiple contours, use the one with the largest area
87+
if len(contours) > 1:
88+
logger.warning("Multiple contours found in the mask. Using the largest one.")
89+
contours = sorted(contours, key=cv2.contourArea, reverse=True)
90+
if len(contours) == 0:
91+
msg = "No contours found in the mask."
92+
raise ValueError(msg)
7993
points_ = contours[0].squeeze().tolist()
8094
return [tuple(point) for point in points_]
8195

@@ -88,6 +102,8 @@ def compute(self, image: Image) -> Image:
88102
Returns:
89103
Image with the polygon drawn on it.
90104
"""
91-
draw = ImageDraw.Draw(image)
92-
draw.polygon(self.points, fill=self.color)
105+
draw = ImageDraw.Draw(image, "RGBA")
106+
# Draw polygon with darker edge and a semi-transparent fill.
107+
ink = ImageColor.getrgb(self.color)
108+
draw.polygon(self.points, fill=(*ink, int(255 * self.opacity)), outline=self.color, width=self.outline_width)
93109
return image

src/python/model_api/visualizer/scene/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from .detection import DetectionScene
99
from .keypoint import KeypointScene
1010
from .scene import Scene
11-
from .segmentation import SegmentationScene
11+
from .segmentation import InstanceSegmentationScene, SegmentationScene
1212
from .visual_prompting import VisualPromptingScene
1313

1414
__all__ = [
1515
"AnomalyScene",
1616
"ClassificationScene",
1717
"DetectionScene",
18+
"InstanceSegmentationScene",
1819
"KeypointScene",
1920
"Scene",
2021
"SegmentationScene",

src/python/model_api/visualizer/scene/detection.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ def _get_overlays(self, result: DetectionResult) -> list[Overlay]:
3232
label_index_mapping = dict(zip(result.labels, result.label_names))
3333
for label_index, label_name in label_index_mapping.items():
3434
# Index 0 as it assumes only one batch
35-
saliency_map = cv2.applyColorMap(result.saliency_map[0][label_index], cv2.COLORMAP_JET)
36-
overlays.append(Overlay(saliency_map, label=label_name.title()))
35+
if result.saliency_map is not None and result.saliency_map.size > 0:
36+
saliency_map = cv2.applyColorMap(result.saliency_map[0][label_index], cv2.COLORMAP_JET)
37+
saliency_map = cv2.cvtColor(saliency_map, cv2.COLOR_BGR2RGB)
38+
overlays.append(Overlay(saliency_map, label=label_name.title()))
3739
return overlays
3840

3941
def _get_bounding_boxes(self, result: DetectionResult) -> list[BoundingBox]:

src/python/model_api/visualizer/scene/keypoint.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,30 @@
33
# Copyright (C) 2024 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

6+
from typing import Union
7+
8+
from PIL import Image
9+
610
from model_api.models.result import DetectedKeypoints
711
from model_api.visualizer.layout import Flatten, Layout
8-
from model_api.visualizer.primitive import Overlay
12+
from model_api.visualizer.primitive import Keypoint
913

1014
from .scene import Scene
1115

1216

1317
class KeypointScene(Scene):
1418
"""Keypoint Scene."""
1519

16-
def __init__(self, result: DetectedKeypoints) -> None:
17-
self.result = result
20+
def __init__(self, image: Image, result: DetectedKeypoints, layout: Union[Layout, None] = None) -> None:
21+
super().__init__(
22+
base=image,
23+
keypoints=self._get_keypoints(result),
24+
layout=layout,
25+
)
26+
27+
def _get_keypoints(self, result: DetectedKeypoints) -> list[Keypoint]:
28+
return [Keypoint(result.keypoints, result.scores)]
1829

1930
@property
2031
def default_layout(self) -> Layout:
21-
return Flatten(Overlay)
32+
return Flatten(Keypoint)

src/python/model_api/visualizer/scene/scene.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111
from PIL import Image
1212

13-
from model_api.visualizer.primitive import BoundingBox, Label, Overlay, Polygon, Primitive
13+
from model_api.visualizer.primitive import BoundingBox, Keypoint, Label, Overlay, Polygon, Primitive
1414

1515
if TYPE_CHECKING:
1616
from pathlib import Path
@@ -31,13 +31,15 @@ def __init__(
3131
label: Label | list[Label] | None = None,
3232
overlay: Overlay | list[Overlay] | np.ndarray | None = None,
3333
polygon: Polygon | list[Polygon] | None = None,
34+
keypoints: Keypoint | list[Keypoint] | np.ndarray | None = None,
3435
layout: Layout | None = None,
3536
) -> None:
3637
self.base = base
3738
self.overlay = self._to_overlay(overlay)
3839
self.bounding_box = self._to_bounding_box(bounding_box)
3940
self.label = self._to_label(label)
4041
self.polygon = self._to_polygon(polygon)
42+
self.keypoints = self._to_keypoints(keypoints)
4143
self.layout = layout
4244

4345
def show(self) -> None:
@@ -60,6 +62,8 @@ def has_primitives(self, primitive: type[Primitive]) -> bool:
6062
return bool(self.label)
6163
if primitive == Polygon:
6264
return bool(self.polygon)
65+
if primitive == Keypoint:
66+
return bool(self.keypoints)
6367
return False
6468

6569
def get_primitives(self, primitive: type[Primitive]) -> list[Primitive]:
@@ -86,6 +90,8 @@ def get_primitives(self, primitive: type[Primitive]) -> list[Primitive]:
8690
primitives = cast("list[Primitive]", self.label)
8791
elif primitive == Polygon:
8892
primitives = cast("list[Primitive]", self.polygon)
93+
elif primitive == Keypoint:
94+
primitives = cast("list[Primitive]", self.keypoints)
8995
else:
9096
msg = f"Primitive {primitive} not found"
9197
raise ValueError(msg)
@@ -119,3 +125,10 @@ def _to_polygon(self, polygon: Polygon | list[Polygon] | None) -> list[Polygon]
119125
if isinstance(polygon, Polygon):
120126
return [polygon]
121127
return polygon
128+
129+
def _to_keypoints(self, keypoints: Keypoint | list[Keypoint] | np.ndarray | None) -> list[Keypoint] | None:
130+
if isinstance(keypoints, Keypoint):
131+
return [keypoints]
132+
if isinstance(keypoints, np.ndarray):
133+
return [Keypoint(keypoints)]
134+
return keypoints

src/python/model_api/visualizer/scene/segmentation.py

-15
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Segmentation Scene."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from .instance_segmentation import InstanceSegmentationScene
7+
from .segmentation import SegmentationScene
8+
9+
__all__ = [
10+
"InstanceSegmentationScene",
11+
"SegmentationScene",
12+
]

0 commit comments

Comments
 (0)