Skip to content

Commit 01832f5

Browse files
Expose Visualizer class (#426)
* initial commit Signed-off-by: Igor Davidyuk <igor.davidyuk@intel.com> merge python version * expand visualizer docs Signed-off-by: Igor Davidyuk <igor.davidyuk@intel.com> * visualizer and notebook fixes Signed-off-by: Igor Davidyuk <igor.davidyuk@intel.com> * fix benchmarker tests Signed-off-by: Igor Davidyuk <igor.davidyuk@intel.com> * fix 102 103 use case notebooks Signed-off-by: Igor Davidyuk <igor.davidyuk@intel.com> * fix 005 notebook example code Signed-off-by: Igor Davidyuk <igor.davidyuk@intel.com> * change visualizer import path Signed-off-by: Igor Davidyuk <igor.davidyuk@intel.com> * 008 notebook add xai explanation text Signed-off-by: Igor Davidyuk <igor.davidyuk@intel.com> * Apply suggestions from code review Changing the visualizer import path Co-authored-by: Ludo Cornelissen <ludo.cornelissen@intel.com> --------- Signed-off-by: Igor Davidyuk <igor.davidyuk@intel.com> Co-authored-by: Ludo Cornelissen <ludo.cornelissen@intel.com>
1 parent efd325c commit 01832f5

16 files changed

+264
-202
lines changed

docs/source/api_reference.rst

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ API Reference
99
Data models <geti_sdk.data_models>
1010
Import Export module <geti_sdk.import_export>
1111
Deployment <geti_sdk.deployment>
12+
Prediction Visualization <geti_sdk.prediction_visualization>
1213
HTTP session <geti_sdk.http_session>
1314
REST converters <geti_sdk.rest_converters>
1415
REST clients <geti_sdk.rest_clients>

geti_sdk/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@
111111
"""
112112

113113
from .geti import Geti
114+
from .prediction_visualization.visualizer import Visualizer
114115

115116
__version__ = "2.1.0"
116117

117-
__all__ = ["Geti"]
118+
__all__ = ["Geti", "Visualizer"]

geti_sdk/benchmarking/benchmarker.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@
3333
Video,
3434
)
3535
from geti_sdk.deployment import Deployment
36+
from geti_sdk.prediction_visualization.visualizer import Visualizer
3637
from geti_sdk.rest_clients import ImageClient, ModelClient, TrainingClient, VideoClient
3738
from geti_sdk.rest_clients.prediction_client import PredictionClient
3839
from geti_sdk.utils.plot_helpers import (
3940
concat_prediction_results,
4041
pad_image_and_put_caption,
41-
show_image_with_annotation_scene,
4242
)
4343

4444
from .utils import get_system_info, load_benchmark_media, suppress_log_output
@@ -859,6 +859,8 @@ def compare_predictions(
859859
with open(throughput_benchmark_results, "r") as results_file:
860860
throughput_benchmark_results = list(csv.DictReader(results_file))
861861

862+
visualizer = Visualizer()
863+
862864
# Performe inferece
863865
with logging_redirect_tqdm(tqdm_class=tqdm):
864866
results: List[List[np.ndarray]] = []
@@ -890,9 +892,7 @@ def compare_predictions(
890892
f"failed. Inference failed with error: `{e}`"
891893
)
892894
if success:
893-
image_with_prediction = show_image_with_annotation_scene(
894-
image, prediction, show_results=False
895-
)
895+
image_with_prediction = visualizer.draw(image, prediction)
896896
image_with_prediction = cv2.cvtColor(
897897
image_with_prediction, cv2.COLOR_BGR2RGB
898898
)
@@ -953,8 +953,8 @@ def compare_predictions(
953953
if include_online_prediction_for_active_model:
954954
logging.info("Predicting on the platform using the active model")
955955
online_prediction_result = self._predict_using_active_model(image)
956-
image_with_prediction = show_image_with_annotation_scene(
957-
image, online_prediction_result["prediction"], show_results=False
956+
image_with_prediction = visualizer.draw(
957+
image, online_prediction_result["prediction"]
958958
)
959959
image_with_prediction = cv2.cvtColor(
960960
image_with_prediction, cv2.COLOR_BGR2RGB

geti_sdk/deployment/resources/OVMS_README.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
6262
predictions = deployment.infer(image=image)
6363

6464
# Show inference result
65-
from geti_sdk.utils import show_image_with_annotation_scene
66-
show_image_with_annotation_scene(image=image, annotation_scene=predictions);
65+
from geti_sdk import Visualizer
66+
visualizer = Visualizer()
67+
result_image = visualizer.draw(image=image, annotation_scene=predictions)
68+
visualizer.show_window(result_image)
6769
```
6870

6971
The example uses a sample image, please make sure to replace it with your own.

geti_sdk/post_inference_hooks/actions/file_system_data_collection.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323
from geti_sdk.data_models import Prediction
2424
from geti_sdk.deployment.inference_hook_interfaces import PostInferenceAction
25+
from geti_sdk.prediction_visualization.visualizer import Visualizer
2526
from geti_sdk.rest_converters import PredictionRESTConverter
26-
from geti_sdk.utils import show_image_with_annotation_scene
2727

2828

2929
class FileSystemDataCollection(PostInferenceAction):
@@ -81,6 +81,7 @@ def __init__(
8181
self.save_predictions = save_predictions
8282
self.save_scores = save_scores
8383
self.save_overlays = save_overlays
84+
self.visualizer = Visualizer()
8485

8586
self._repr_info_ = (
8687
f"target_folder=`{target_folder}`, "
@@ -147,12 +148,8 @@ def __call__(
147148

148149
if self.save_overlays:
149150
overlay_path = os.path.join(self.overlays_path, filename + ".jpg")
150-
show_image_with_annotation_scene(
151-
image=image,
152-
annotation_scene=prediction,
153-
filepath=overlay_path,
154-
show_results=False,
155-
)
151+
result = self.visualizer.draw(image, prediction)
152+
self.visualizer.save_image(result, overlay_path)
156153
except Exception as e:
157154
logging.exception(e, stack_info=True, exc_info=True)
158155

geti_sdk/prediction_visualization/__init__.py

+42-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,48 @@
1616
Introduction
1717
------------
1818
19-
The `prediction_visualization` package provides classes for visualizing models predictions.
20-
Currently, the user interfaces to this package are available in the :py:mod:`~geti_sdk.utils.plot_helpers` module.
19+
The `prediction_visualization` package provides classes for visualizing models predictions and media annotations.
20+
Aditionally, shortend interface to this package is available through the :py:mod:`~geti_sdk.utils.plot_helpers` module.
21+
22+
The main :py:class:`~geti_sdk.prediction_visualization.visualizer.Visualizer` class is a flexible utility class for working
23+
with Geti-SDK Prediction and Annotation object. You can initialize the Visualizer with the desired settings and then use it to draw
24+
the annotations on the input image.
25+
26+
.. code-block:: python
27+
28+
from geti_sdk import Visualizer
29+
30+
visualizer = Visualizer(
31+
show_labels=True,
32+
show_confidence=True,
33+
show_count=False,
34+
)
35+
36+
# Obtain a prediction from the Intel Geti platfor server or a local deployment.
37+
...
38+
39+
# Visualize the prediction on the input image.
40+
result = visualizer.draw(
41+
numpy_image,
42+
prediction,
43+
fill_shapes=True,
44+
confidence_threshold=0.4,
45+
)
46+
visualizer.show_in_notebook(result)
47+
48+
In case the Prediction was generated with a model that supports explainable AI functionality, the Visualizer can also display
49+
the explanation for the prediction.
50+
51+
.. code-block:: python
52+
image_with_saliency_map = visualizer.explain_label(
53+
numpy_image,
54+
prediction,
55+
label_name="Cat",
56+
opacity=0.5,
57+
show_predictions=True,
58+
)
59+
visualizer.save_image(image_with_saliency_map, "./explained_prediction.jpg")
60+
visualizer.show_window(image_with_saliency_map) # When called in a script
2161
2262
Module contents
2363
---------------

geti_sdk/prediction_visualization/visualizer.py

+56-14
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515
"""The package provides the Visualizer class for models predictions visualization."""
1616

1717

18-
from typing import Optional
18+
from os import PathLike
19+
from typing import List, Optional, Union
1920

2021
import cv2
2122
import numpy as np
23+
from IPython.display import display
24+
from PIL import Image
2225

2326
from geti_sdk.data_models.annotation_scene import AnnotationScene
27+
from geti_sdk.data_models.containers.media_list import MediaList
28+
from geti_sdk.data_models.media import VideoFrame
2429
from geti_sdk.data_models.predictions import Prediction
2530
from geti_sdk.prediction_visualization.shape_drawer import ShapeDrawer
2631

@@ -44,8 +49,6 @@ def __init__(
4449
show_confidence: bool = True,
4550
show_count: bool = False,
4651
is_one_label: bool = False,
47-
delay: Optional[int] = None,
48-
output: Optional[str] = None,
4952
) -> None:
5053
"""
5154
Initialize the Visualizer.
@@ -55,19 +58,12 @@ def __init__(
5558
:param show_confidence: Show confidence on the output image
5659
:param show_count: Show count of the shapes on the output image
5760
:param is_one_label: Show only one label on the output image
58-
:param delay: Delay time for the output image
59-
:param output: Path to save the output image
6061
"""
6162
self.window_name = "Window" if window_name is None else window_name
6263
self.shape_drawer = ShapeDrawer(
6364
show_count, is_one_label, show_labels, show_confidence
6465
)
6566

66-
self.delay = delay
67-
if delay is None:
68-
self.delay = 1
69-
self.output = output
70-
7167
def draw(
7268
self,
7369
image: np.ndarray,
@@ -90,7 +86,7 @@ def draw(
9086
if confidence_threshold is not None:
9187
annotation = annotation.filter_by_confidence(confidence_threshold)
9288
result = self.shape_drawer.draw(
93-
image, annotation, labels=[], fill_shapes=fill_shapes
89+
image.copy(), annotation, labels=[], fill_shapes=fill_shapes
9490
)
9591
return result
9692

@@ -140,7 +136,53 @@ def explain_label(
140136
result = self.draw(result, filtered_prediction, fill_shapes=False)
141137
return result
142138

143-
def show(self, image: np.ndarray) -> None:
139+
@staticmethod
140+
def save_image(image: np.ndarray, output_path: PathLike) -> None:
141+
"""
142+
Save the image to the output path.
143+
144+
:param image: Image in RGB format to be saved
145+
:param output_path: Path to save the image
146+
"""
147+
bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
148+
cv2.imwrite(output_path, bgr_image)
149+
150+
@staticmethod
151+
def save_video(
152+
video_frames: MediaList[VideoFrame],
153+
annotation_scenes: List[Union[AnnotationScene, Prediction]],
154+
output_path: PathLike,
155+
fps: float = 1,
156+
) -> None:
157+
"""
158+
Save the video to the output path.
159+
160+
:param video_frames: List of video frames
161+
:param annotation_scenes: List of annotation scenes to be drawn on the video frames
162+
:param output_path: Path to save the image
163+
"""
164+
out_writer = cv2.VideoWriter(
165+
filename=f"{output_path}",
166+
fourcc=cv2.VideoWriter_fourcc("M", "J", "P", "G"),
167+
fps=fps,
168+
frameSize=(
169+
video_frames[0].media_information.width,
170+
video_frames[0].media_information.height,
171+
),
172+
)
173+
for frame, annotation in zip(video_frames, annotation_scenes):
174+
out_writer.write(frame)
175+
176+
@staticmethod
177+
def show_in_notebook(image: np.ndarray) -> None:
178+
"""
179+
Show the image in the Jupyter notebook.
180+
181+
:param image: Image to be shown in RGB format
182+
"""
183+
display(Image.fromarray(image))
184+
185+
def show_window(self, image: np.ndarray) -> None:
144186
"""
145187
Show result image.
146188
@@ -149,6 +191,6 @@ def show(self, image: np.ndarray) -> None:
149191
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
150192
cv2.imshow(self.window_name, image_bgr)
151193

152-
def is_quit(self) -> bool:
194+
def is_quit(self, delay: int = 1) -> bool:
153195
"""Check user wish to quit."""
154-
return ord("q") == cv2.waitKey(self.delay)
196+
return ord("q") == cv2.waitKey(delay)

notebooks/003_upload_and_predict_image.ipynb

+12-16
Original file line numberDiff line numberDiff line change
@@ -207,15 +207,19 @@
207207
"metadata": {},
208208
"outputs": [],
209209
"source": [
210-
"from geti_sdk.utils import show_image_with_annotation_scene\n",
210+
"import cv2\n",
211+
"\n",
212+
"from geti_sdk import Visualizer\n",
211213
"\n",
212214
"# To visualise the image, we have to retrieve the pixel data from the platform using the `image.get_data` method. The actual pixel data is\n",
213215
"# downloaded and cached only on the first call to this method\n",
214216
"image.get_data(geti.session)\n",
217+
"numpy_image = image.numpy\n",
215218
"\n",
216-
"show_image_with_annotation_scene(\n",
217-
" image, prediction, show_in_notebook=True, channel_order=\"bgr\"\n",
218-
");"
219+
"visualizer = Visualizer()\n",
220+
"image_rgb = cv2.cvtColor(numpy_image, cv2.COLOR_BGR2RGB)\n",
221+
"result = visualizer.draw(image_rgb, prediction)\n",
222+
"visualizer.show_in_notebook(result)"
219223
]
220224
},
221225
{
@@ -240,18 +244,10 @@
240244
" visualise_output=False,\n",
241245
" delete_after_prediction=False,\n",
242246
")\n",
243-
"show_image_with_annotation_scene(\n",
244-
" quick_image, quick_prediction, show_in_notebook=True, channel_order=\"bgr\"\n",
245-
");"
247+
"quick_image_rgb = cv2.cvtColor(quick_image.numpy, cv2.COLOR_BGR2RGB)\n",
248+
"quick_result = visualizer.draw(quick_image_rgb, quick_prediction)\n",
249+
"visualizer.show_in_notebook(quick_result)"
246250
]
247-
},
248-
{
249-
"cell_type": "code",
250-
"execution_count": null,
251-
"id": "51090376-c85e-4af3-9ff8-b030934fd095",
252-
"metadata": {},
253-
"outputs": [],
254-
"source": []
255251
}
256252
],
257253
"metadata": {
@@ -270,7 +266,7 @@
270266
"name": "python",
271267
"nbconvert_exporter": "python",
272268
"pygments_lexer": "ipython3",
273-
"version": "3.8.16"
269+
"version": "3.10.12"
274270
}
275271
},
276272
"nbformat": 4,

notebooks/005_modify_image.ipynb

+11-10
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@
219219
"metadata": {},
220220
"outputs": [],
221221
"source": [
222+
"from geti_sdk import Visualizer\n",
222223
"from geti_sdk.rest_clients import AnnotationClient\n",
223-
"from geti_sdk.utils import show_image_with_annotation_scene\n",
224224
"\n",
225225
"annotation_client = AnnotationClient(\n",
226226
" session=geti.session, workspace_id=geti.workspace_id, project=project\n",
@@ -231,9 +231,11 @@
231231
"\n",
232232
"# Inspect the annotation\n",
233233
"print(annotation.overview)\n",
234-
"show_image_with_annotation_scene(\n",
235-
" image, annotation, show_in_notebook=True, channel_order=\"bgr\"\n",
236-
");"
234+
"\n",
235+
"visualizer = Visualizer()\n",
236+
"image_rgb = cv2.cvtColor(image.numpy, cv2.COLOR_BGR2RGB)\n",
237+
"result = visualizer.draw(image_rgb, annotation)\n",
238+
"visualizer.show_in_notebook(result)"
237239
]
238240
},
239241
{
@@ -276,11 +278,10 @@
276278
")\n",
277279
"\n",
278280
"# Inspect the annotation\n",
279-
"show_image_with_annotation_scene(\n",
280-
" grayscale_image.get_data(geti.session),\n",
281-
" grayscale_annotation,\n",
282-
" show_in_notebook=True,\n",
283-
");"
281+
"result = visualizer.draw(\n",
282+
" grayscale_image.get_data(geti.session).numpy, grayscale_annotation\n",
283+
")\n",
284+
"visualizer.show_in_notebook(result)"
284285
]
285286
},
286287
{
@@ -342,7 +343,7 @@
342343
"name": "python",
343344
"nbconvert_exporter": "python",
344345
"pygments_lexer": "ipython3",
345-
"version": "3.8.16"
346+
"version": "3.10.12"
346347
}
347348
},
348349
"nbformat": 4,

0 commit comments

Comments
 (0)