|
3 | 3 | # Copyright (C) 2024 Intel Corporation
|
4 | 4 | # SPDX-License-Identifier: Apache-2.0
|
5 | 5 |
|
| 6 | +from itertools import starmap |
| 7 | +from typing import Union |
| 8 | + |
| 9 | +import cv2 |
6 | 10 | from PIL import Image
|
7 | 11 |
|
8 | 12 | from model_api.models.result import AnomalyResult
|
9 | 13 | from model_api.visualizer.layout import Flatten, Layout
|
10 |
| -from model_api.visualizer.primitive import Overlay |
| 14 | +from model_api.visualizer.primitive import BoundingBox, Label, Overlay, Polygon |
11 | 15 |
|
12 | 16 | from .scene import Scene
|
13 | 17 |
|
14 | 18 |
|
15 | 19 | class AnomalyScene(Scene):
|
16 | 20 | """Anomaly Scene."""
|
17 | 21 |
|
18 |
| - def __init__(self, image: Image, result: AnomalyResult) -> None: |
19 |
| - self.image = image |
20 |
| - self.result = result |
| 22 | + def __init__(self, image: Image, result: AnomalyResult, layout: Union[Layout, None] = None) -> None: |
| 23 | + super().__init__( |
| 24 | + base=image, |
| 25 | + overlay=self._get_overlays(result), |
| 26 | + bounding_box=self._get_bounding_boxes(result), |
| 27 | + label=self._get_labels(result), |
| 28 | + polygon=self._get_polygons(result), |
| 29 | + layout=layout, |
| 30 | + ) |
| 31 | + |
| 32 | + def _get_overlays(self, result: AnomalyResult) -> list[Overlay]: |
| 33 | + if result.anomaly_map is not None: |
| 34 | + anomaly_map = cv2.cvtColor(result.anomaly_map, cv2.COLOR_BGR2RGB) |
| 35 | + return [Overlay(anomaly_map)] |
| 36 | + return [] |
| 37 | + |
| 38 | + def _get_bounding_boxes(self, result: AnomalyResult) -> list[BoundingBox]: |
| 39 | + if result.pred_boxes is not None: |
| 40 | + return list(starmap(BoundingBox, result.pred_boxes)) |
| 41 | + return [] |
| 42 | + |
| 43 | + def _get_labels(self, result: AnomalyResult) -> list[Label]: |
| 44 | + labels = [] |
| 45 | + if result.pred_label is not None: |
| 46 | + labels.append(Label(result.pred_label)) |
| 47 | + if result.pred_score is not None: |
| 48 | + labels.append(Label(result.pred_score)) |
| 49 | + return labels |
| 50 | + |
| 51 | + def _get_polygons(self, result: AnomalyResult) -> list[Polygon]: |
| 52 | + if result.pred_mask is not None: |
| 53 | + return [Polygon(result.pred_mask)] |
| 54 | + return [] |
21 | 55 |
|
22 | 56 | @property
|
23 | 57 | def default_layout(self) -> Layout:
|
24 |
| - return Flatten(Overlay) |
| 58 | + return Flatten(Overlay, BoundingBox, Label, Polygon) |
0 commit comments