12
12
from detectionmetrics .datasets import dataset as dm_dataset
13
13
from detectionmetrics .models import model as dm_model
14
14
from detectionmetrics .models import torch_model_utils as tmu
15
- import detectionmetrics .utils .conversion as uc
16
- import detectionmetrics .utils .io as uio
17
15
import detectionmetrics .utils .lidar as ul
18
16
import detectionmetrics .utils .metrics as um
19
17
@@ -224,11 +222,24 @@ def __init__(self, model_fname: str, model_cfg: str, ontology_fname: str):
224
222
)
225
223
]
226
224
227
- self .transform_input += [
228
- transforms .ToImage (),
229
- transforms .ToDtype (torch .float32 , scale = True ),
230
- ]
231
- self .transform_label += [transforms .ToImage (), transforms .ToDtype (torch .int64 )]
225
+ try :
226
+ self .transform_input += [
227
+ transforms .ToImage (),
228
+ transforms .ToDtype (torch .float32 , scale = True ),
229
+ ]
230
+ self .transform_label += [
231
+ transforms .ToImage (),
232
+ transforms .ToDtype (torch .int64 ),
233
+ ]
234
+ except AttributeError : # adapt for older versions of torchvision transforms v2
235
+ self .transform_input += [
236
+ transforms .ToImageTensor (),
237
+ transforms .ConvertDtype (torch .float32 ),
238
+ ]
239
+ self .transform_label += [
240
+ transforms .ToImageTensor (),
241
+ transforms .ToDtype (torch .int64 ),
242
+ ]
232
243
233
244
if "normalization" in self .model_cfg :
234
245
self .transform_input += [
@@ -311,7 +322,7 @@ def eval(
311
322
# Init metrics
312
323
results = {}
313
324
iou = um .IoU (self .n_classes )
314
- acc = um .Accuracy (self .n_classes )
325
+ cm = um .ConfusionMatrix (self .n_classes )
315
326
316
327
# Evaluation loop
317
328
with torch .no_grad ():
@@ -335,13 +346,13 @@ def eval(
335
346
if lut_ontology is not None :
336
347
label = lut_ontology [label ]
337
348
338
- # Prepare data and update accuracy
349
+ # Prepare data and update confusion matrix
339
350
label = label .squeeze (dim = 1 ).cpu ()
340
351
pred = torch .argmax (pred , axis = 1 ).cpu ()
341
352
if valid_mask is not None :
342
353
valid_mask = valid_mask .squeeze (dim = 1 ).cpu ()
343
354
344
- acc .update (
355
+ cm .update (
345
356
pred .numpy (),
346
357
label .numpy (),
347
358
valid_mask .numpy () if valid_mask is not None else None ,
@@ -363,7 +374,7 @@ def eval(
363
374
364
375
# Get metrics results
365
376
iou_per_class , iou = iou .compute ()
366
- acc_per_class , acc = acc . compute ()
377
+ acc_per_class , acc = cm . get_accuracy ()
367
378
iou_per_class = [float (n ) for n in iou_per_class ]
368
379
acc_per_class = [float (n ) for n in acc_per_class ]
369
380
@@ -526,7 +537,7 @@ def eval(
526
537
527
538
# Init metrics
528
539
iou = um .IoU (self .n_classes )
529
- acc = um .Accuracy (self .n_classes )
540
+ cm = um .ConfusionMatrix (self .n_classes )
530
541
531
542
# Evaluation loop
532
543
results = {}
@@ -590,13 +601,13 @@ def eval(
590
601
if lut_ontology is not None :
591
602
label = lut_ontology [label ]
592
603
593
- # Prepare data and update accuracy
604
+ # Prepare data and update confusion matrix
594
605
label = label .cpu ().unsqueeze (0 )
595
606
pred = self .transform_output (pred ).cpu ().unsqueeze (0 ).to (torch .int64 )
596
607
if valid_mask is not None :
597
608
valid_mask = valid_mask .cpu ().unsqueeze (0 )
598
609
599
- acc .update (
610
+ cm .update (
600
611
pred .numpy (),
601
612
label .numpy (),
602
613
valid_mask .numpy () if valid_mask is not None else None ,
@@ -618,7 +629,7 @@ def eval(
618
629
619
630
# Get metrics results
620
631
iou_per_class , iou = iou .compute ()
621
- acc_per_class , acc = acc . compute ()
632
+ acc_per_class , acc = cm . get_accuracy ()
622
633
iou_per_class = [float (n ) for n in iou_per_class ]
623
634
acc_per_class = [float (n ) for n in acc_per_class ]
624
635
0 commit comments