15
15
"""Module implements the InferenceResultsToPredictionConverter class."""
16
16
17
17
import abc
18
+ import logging
18
19
from typing import Any , Dict , List , NamedTuple , Optional , Tuple , Union
19
20
20
21
import cv2
31
32
from geti_sdk .data_models .annotations import Annotation
32
33
from geti_sdk .data_models .containers import LabelList
33
34
from geti_sdk .data_models .enums .domain import Domain
34
- from geti_sdk .data_models .label import ScoredLabel
35
+ from geti_sdk .data_models .label import Label , ScoredLabel
35
36
from geti_sdk .data_models .predictions import Prediction
36
37
from geti_sdk .data_models .shapes import (
37
38
Ellipse ,
48
49
class InferenceResultsToPredictionConverter (metaclass = abc .ABCMeta ):
49
50
"""Interface for the converter"""
50
51
51
- def __init__ (
52
- self , labels : LabelList , configuration : Optional [Dict [str , Any ]] = None
53
- ):
52
+ def __init__ (self , labels : LabelList , configuration : Dict [str , Any ]):
54
53
self .labels = labels .get_non_empty_labels ()
55
54
self .empty_label = labels .get_empty_label ()
56
55
self .configuration = configuration
56
+ self .is_labels_sorted = "label_ids" in configuration
57
+ if self .is_labels_sorted :
58
+ # Make sure the list of labels is sorted according to the order
59
+ # defined in the ModelAPI configuration.
60
+ # - If the 'label_ids' field only contains a single label,
61
+ # it will be typed as string. No need to sort in that case.
62
+ # - Filter out the empty label ID, as it is managed separately by the base converter class.
63
+ ids = configuration ["label_ids" ]
64
+ if not isinstance (ids , str ):
65
+ ids = [
66
+ id_
67
+ for id_ in ids
68
+ if not self .empty_label or id_ != self .empty_label .id
69
+ ]
70
+ self .labels .sort_by_ids (ids )
57
71
58
72
@abc .abstractmethod
59
73
def convert_to_prediction (
@@ -89,9 +103,7 @@ class ClassificationToPredictionConverter(InferenceResultsToPredictionConverter)
89
103
parameters
90
104
"""
91
105
92
- def __init__ (
93
- self , labels : LabelList , configuration : Optional [Dict [str , Any ]] = None
94
- ):
106
+ def __init__ (self , labels : LabelList , configuration : Dict [str , Any ]):
95
107
super ().__init__ (labels , configuration )
96
108
97
109
def convert_to_prediction (
@@ -110,11 +122,18 @@ def convert_to_prediction(
110
122
labels = []
111
123
for label in inference_results .top_labels :
112
124
label_idx , label_name , label_prob = label
113
- # label_idx does not necessarily match the label index in the project
114
- # labels. Therefore, we map the label by name instead.
115
- labels .append (
116
- self .labels .create_scored_label (id_or_name = label_name , score = label_prob )
117
- )
125
+ if self .is_labels_sorted :
126
+ scored_label = ScoredLabel .from_label (
127
+ label = self .labels [label_idx ], probability = label_prob
128
+ )
129
+ else :
130
+ # label_idx does not necessarily match the label index in the project
131
+ # labels. Therefore, we map the label by name instead.
132
+ _label = self ._get_label_by_prediction_name (name = label_name )
133
+ scored_label = ScoredLabel .from_label (
134
+ label = _label , probability = label_prob
135
+ )
136
+ labels .append (scored_label )
118
137
119
138
if not labels and self .empty_label :
120
139
labels = [ScoredLabel .from_label (self .empty_label , probability = 0 )]
@@ -153,6 +172,27 @@ def convert_saliency_map(
153
172
for i , label in enumerate (self .labels .get_non_empty_labels ())
154
173
}
155
174
175
+ def _get_label_by_prediction_name (self , name : str ) -> Label :
176
+ """
177
+ Get a Label object by its predicted name.
178
+
179
+ :param name: predicted name of the label
180
+ :return: Label corresponding to the name
181
+ :raises KeyError: if the label is not found in the LabelList
182
+ """
183
+ try :
184
+ return self .labels .get_by_name (name = name )
185
+ except KeyError :
186
+ # If the label is not found, we try to find it by legacy name (replacing spaces with underscores)
187
+ for label in self .labels :
188
+ legacy_name = label .name .replace (" " , "_" )
189
+ if legacy_name == name :
190
+ logging .warning (
191
+ f"Found label `{ label .name } ` using its legacy name `{ legacy_name } `."
192
+ )
193
+ return label
194
+ raise KeyError (f"Label named `{ name } ` was not found in the LabelList" )
195
+
156
196
157
197
class DetectionToPredictionConverter (InferenceResultsToPredictionConverter ):
158
198
"""
@@ -162,27 +202,14 @@ class DetectionToPredictionConverter(InferenceResultsToPredictionConverter):
162
202
:param configuration: optional model configuration setting
163
203
"""
164
204
165
- def __init__ (
166
- self , labels : LabelList , configuration : Optional [Dict [str , Any ]] = None
167
- ):
205
+ def __init__ (self , labels : LabelList , configuration : Dict [str , Any ]):
168
206
super ().__init__ (labels , configuration )
169
207
self .use_ellipse_shapes = False
170
208
self .confidence_threshold = 0.0
171
- if configuration is not None :
172
- if "use_ellipse_shapes" in configuration :
173
- self .use_ellipse_shapes = configuration ["use_ellipse_shapes" ]
174
- if "confidence_threshold" in configuration :
175
- self .confidence_threshold = configuration ["confidence_threshold" ]
176
- if "label_ids" in configuration :
177
- # Make sure the list of labels is sorted according to the order
178
- # defined in the ModelAPI configuration.
179
- # - If the 'label_ids' field only contains a single label,
180
- # it will be typed as string. No need to sort in that case.
181
- # - Filter out the empty label ID, as it is managed separately by the base converter class.
182
- ids = configuration ["label_ids" ]
183
- if not isinstance (ids , str ):
184
- ids = [id_ for id_ in ids if id_ != self .empty_label .id ]
185
- self .labels .sort_by_ids (ids )
209
+ if "use_ellipse_shapes" in configuration :
210
+ self .use_ellipse_shapes = configuration ["use_ellipse_shapes" ]
211
+ if "confidence_threshold" in configuration :
212
+ self .confidence_threshold = configuration ["confidence_threshold" ]
186
213
187
214
def _detection2array (self , detections : List [Detection ]) -> np .ndarray :
188
215
"""
@@ -468,9 +495,7 @@ class SegmentationToPredictionConverter(InferenceResultsToPredictionConverter):
468
495
:param configuration: optional model configuration setting
469
496
"""
470
497
471
- def __init__ (
472
- self , labels : LabelList , configuration : Optional [Dict [str , Any ]] = None
473
- ):
498
+ def __init__ (self , labels : LabelList , configuration : Dict [str , Any ]):
474
499
super ().__init__ (labels , configuration )
475
500
# NB: index=0 is reserved for the background label
476
501
self .label_map = dict (enumerate (self .labels , 1 ))
@@ -518,9 +543,7 @@ class AnomalyToPredictionConverter(InferenceResultsToPredictionConverter):
518
543
:param configuration: optional model configuration setting
519
544
"""
520
545
521
- def __init__ (
522
- self , labels : LabelList , configuration : Optional [Dict [str , Any ]] = None
523
- ):
546
+ def __init__ (self , labels : LabelList , configuration : Dict [str , Any ]):
524
547
super ().__init__ (labels , configuration )
525
548
self .normal_label = next (
526
549
label for label in self .labels if not label .is_anomalous
@@ -629,14 +652,14 @@ class ConverterFactory:
629
652
def create_converter (
630
653
labels : LabelList ,
631
654
domain : Domain ,
632
- configuration : Optional [ Dict [str , Any ]] = None ,
655
+ configuration : Dict [str , Any ],
633
656
) -> InferenceResultsToPredictionConverter :
634
657
"""
635
- Create the appropriate inferencer object according to the model's task.
658
+ Create the appropriate inference converter object according to the model's task.
636
659
637
- :param label_schema : The label schema containing the label info of the task.
660
+ :param labels : The labels of the model
638
661
:param domain: The domain to which the converter applies
639
- :param configuration: Optional configuration for the converter. Defaults to None.
662
+ :param configuration: configuration for the converter
640
663
:return: The created inference result to prediction converter.
641
664
:raises ValueError: If the task type cannot be determined from the label schema.
642
665
"""
0 commit comments