15
15
"""The package provides the Visualizer class for models predictions visualization."""
16
16
17
17
18
- from typing import Optional
18
+ from os import PathLike
19
+ from typing import List , Optional , Union
19
20
20
21
import cv2
21
22
import numpy as np
23
+ from IPython .display import display
24
+ from PIL import Image
22
25
23
26
from geti_sdk .data_models .annotation_scene import AnnotationScene
27
+ from geti_sdk .data_models .containers .media_list import MediaList
28
+ from geti_sdk .data_models .media import VideoFrame
24
29
from geti_sdk .data_models .predictions import Prediction
25
30
from geti_sdk .prediction_visualization .shape_drawer import ShapeDrawer
26
31
@@ -44,8 +49,6 @@ def __init__(
44
49
show_confidence : bool = True ,
45
50
show_count : bool = False ,
46
51
is_one_label : bool = False ,
47
- delay : Optional [int ] = None ,
48
- output : Optional [str ] = None ,
49
52
) -> None :
50
53
"""
51
54
Initialize the Visualizer.
@@ -55,19 +58,12 @@ def __init__(
55
58
:param show_confidence: Show confidence on the output image
56
59
:param show_count: Show count of the shapes on the output image
57
60
:param is_one_label: Show only one label on the output image
58
- :param delay: Delay time for the output image
59
- :param output: Path to save the output image
60
61
"""
61
62
self .window_name = "Window" if window_name is None else window_name
62
63
self .shape_drawer = ShapeDrawer (
63
64
show_count , is_one_label , show_labels , show_confidence
64
65
)
65
66
66
- self .delay = delay
67
- if delay is None :
68
- self .delay = 1
69
- self .output = output
70
-
71
67
def draw (
72
68
self ,
73
69
image : np .ndarray ,
@@ -90,7 +86,7 @@ def draw(
90
86
if confidence_threshold is not None :
91
87
annotation = annotation .filter_by_confidence (confidence_threshold )
92
88
result = self .shape_drawer .draw (
93
- image , annotation , labels = [], fill_shapes = fill_shapes
89
+ image . copy () , annotation , labels = [], fill_shapes = fill_shapes
94
90
)
95
91
return result
96
92
@@ -140,7 +136,53 @@ def explain_label(
140
136
result = self .draw (result , filtered_prediction , fill_shapes = False )
141
137
return result
142
138
143
- def show (self , image : np .ndarray ) -> None :
139
+ @staticmethod
140
+ def save_image (image : np .ndarray , output_path : PathLike ) -> None :
141
+ """
142
+ Save the image to the output path.
143
+
144
+ :param image: Image in RGB format to be saved
145
+ :param output_path: Path to save the image
146
+ """
147
+ bgr_image = cv2 .cvtColor (image , cv2 .COLOR_RGB2BGR )
148
+ cv2 .imwrite (output_path , bgr_image )
149
+
150
+ @staticmethod
151
+ def save_video (
152
+ video_frames : MediaList [VideoFrame ],
153
+ annotation_scenes : List [Union [AnnotationScene , Prediction ]],
154
+ output_path : PathLike ,
155
+ fps : float = 1 ,
156
+ ) -> None :
157
+ """
158
+ Save the video to the output path.
159
+
160
+ :param video_frames: List of video frames
161
+ :param annotation_scenes: List of annotation scenes to be drawn on the video frames
162
+ :param output_path: Path to save the image
163
+ """
164
+ out_writer = cv2 .VideoWriter (
165
+ filename = f"{ output_path } " ,
166
+ fourcc = cv2 .VideoWriter_fourcc ("M" , "J" , "P" , "G" ),
167
+ fps = fps ,
168
+ frameSize = (
169
+ video_frames [0 ].media_information .width ,
170
+ video_frames [0 ].media_information .height ,
171
+ ),
172
+ )
173
+ for frame , annotation in zip (video_frames , annotation_scenes ):
174
+ out_writer .write (frame )
175
+
176
+ @staticmethod
177
+ def show_in_notebook (image : np .ndarray ) -> None :
178
+ """
179
+ Show the image in the Jupyter notebook.
180
+
181
+ :param image: Image to be shown in RGB format
182
+ """
183
+ display (Image .fromarray (image ))
184
+
185
+ def show_window (self , image : np .ndarray ) -> None :
144
186
"""
145
187
Show result image.
146
188
@@ -149,6 +191,6 @@ def show(self, image: np.ndarray) -> None:
149
191
image_bgr = cv2 .cvtColor (image , cv2 .COLOR_RGB2BGR )
150
192
cv2 .imshow (self .window_name , image_bgr )
151
193
152
- def is_quit (self ) -> bool :
194
+ def is_quit (self , delay : int = 1 ) -> bool :
153
195
"""Check user wish to quit."""
154
- return ord ("q" ) == cv2 .waitKey (self . delay )
196
+ return ord ("q" ) == cv2 .waitKey (delay )
0 commit comments