Skip to content

Commit 29d6cc1

Browse files
authored
Merge pull request #253 from JdeRobot/dph/confusion_matrix
Add confusion matrix
2 parents 493926b + 6d254c6 commit 29d6cc1

17 files changed

+203
-159
lines changed

detectionmetrics/cli/batch.py

+26-21
Original file line numberDiff line numberDiff line change
@@ -80,27 +80,32 @@ def batch(command, jobs_cfg):
8080
pbar.set_description(f"Processing {job_id}")
8181

8282
ctx = click.get_current_context()
83-
result = ctx.invoke(
84-
cli_registry[command],
85-
task=jobs_cfg["task"],
86-
input_type=jobs_cfg["input_type"],
87-
model_format=model_cfg["format"],
88-
model=model_cfg["path"],
89-
model_ontology=model_cfg["ontology"],
90-
model_cfg=model_cfg["cfg"],
91-
dataset_format=dataset_cfg["format"],
92-
dataset_fname=dataset_cfg.get("fname", None),
93-
dataset_dir=dataset_cfg.get("dir", None),
94-
train_dataset_dir=dataset_cfg.get("train_dir", None),
95-
val_dataset_dir=dataset_cfg.get("val_dir", None),
96-
test_dataset_dir=dataset_cfg.get("test_dir", None),
97-
data_suffix=dataset_cfg.get("data_suffix", None),
98-
label_suffix=dataset_cfg.get("label_suffix", None),
99-
dataset_ontology=dataset_cfg.get("ontology", None),
100-
split=dataset_cfg["split"],
101-
ontology_translation=jobs_cfg.get("ontology_translation", None),
102-
out_fname=job_out_fname,
103-
)
83+
try:
84+
result = ctx.invoke(
85+
cli_registry[command],
86+
task=jobs_cfg["task"],
87+
input_type=jobs_cfg["input_type"],
88+
model_format=model_cfg["format"],
89+
model=model_cfg["path"],
90+
model_ontology=model_cfg["ontology"],
91+
model_cfg=model_cfg["cfg"],
92+
dataset_format=dataset_cfg["format"],
93+
dataset_fname=dataset_cfg.get("fname", None),
94+
dataset_dir=dataset_cfg.get("dir", None),
95+
split_dir=dataset_cfg.get("split_dir", None),
96+
train_dataset_dir=dataset_cfg.get("train_dir", None),
97+
val_dataset_dir=dataset_cfg.get("val_dir", None),
98+
test_dataset_dir=dataset_cfg.get("test_dir", None),
99+
data_suffix=dataset_cfg.get("data_suffix", None),
100+
label_suffix=dataset_cfg.get("label_suffix", None),
101+
dataset_ontology=dataset_cfg.get("ontology", None),
102+
split=dataset_cfg["split"],
103+
ontology_translation=jobs_cfg.get("ontology_translation", None),
104+
out_fname=job_out_fname,
105+
)
106+
except Exception as e:
107+
print(f"Error processing job {job_id}: {e}")
108+
continue
104109

105110
# We assume that the command returns the results as a pandas DataFrame
106111
result["job_id"] = job_id

detectionmetrics/cli/evaluate.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def get_dataset(
2121
dataset_format,
2222
dataset_fname,
2323
dataset_dir,
24+
split_dir,
2425
train_dataset_dir,
2526
val_dataset_dir,
2627
test_dataset_dir,
@@ -33,8 +34,13 @@ def get_dataset(
3334
if dataset_format == "gaia" and dataset_fname is None:
3435
raise ValueError("--dataset is required for 'gaia' format")
3536

36-
elif dataset_format == "rellis3d" and dataset_dir is None:
37-
raise ValueError("--dataset_dir is required for 'rellis3d' format")
37+
elif dataset_format == "rellis3d":
38+
if dataset_dir is None:
39+
raise ValueError("--dataset_dir is required for 'rellis3d' format")
40+
if split_dir is None:
41+
raise ValueError("--split_dir is required for 'rellis3d' format")
42+
if ontology is None:
43+
raise ValueError("--dataset_ontology is required for 'rellis3d' format")
3844

3945
elif dataset_format in ["goose", "generic"]:
4046
if split == "train" and train_dataset_dir is None:
@@ -62,7 +68,11 @@ def get_dataset(
6268
if dataset_format == "gaia":
6369
dataset_args = {"dataset_fname": dataset_fname}
6470
elif dataset_format == "rellis3d":
65-
dataset_args = {"dataset_dir": dataset_dir}
71+
dataset_args = {
72+
"dataset_dir": dataset_dir,
73+
"split_dir": split_dir,
74+
"ontology_fname": ontology,
75+
}
6676
elif dataset_format == "goose":
6777
dataset_args = {
6878
"train_dataset_dir": train_dataset_dir,
@@ -99,7 +109,9 @@ def get_dataset(
99109
# model
100110
@click.option(
101111
"--model_format",
102-
type=click.Choice(["torch", "tensorflow"], case_sensitive=False),
112+
type=click.Choice(
113+
["torch", "tensorflow", "tensorflow_explicit"], case_sensitive=False
114+
),
103115
show_default=True,
104116
default="torch",
105117
help="Trained model format",
@@ -140,6 +152,11 @@ def get_dataset(
140152
type=click.Path(exists=True, file_okay=False, dir_okay=True),
141153
help="Dataset directory (used for 'Rellis3D' format)",
142154
)
155+
@click.option(
156+
"--split_dir",
157+
type=click.Path(exists=True, file_okay=False, dir_okay=True),
158+
help="Directory containing .lst split files (used for 'Rellis3D' format)",
159+
)
143160
@click.option(
144161
"--train_dataset_dir",
145162
type=click.Path(exists=True, file_okay=False, dir_okay=True),
@@ -168,7 +185,7 @@ def get_dataset(
168185
@click.option(
169186
"--dataset_ontology",
170187
type=click.Path(exists=True, dir_okay=False),
171-
help="JSON file containing dataset ontology (used for 'Generic' format)",
188+
help="JSON containing dataset ontology (used for 'Generic' and 'Rellis3D' formats)",
172189
)
173190
@click.option(
174191
"--split",
@@ -199,6 +216,7 @@ def evaluate(
199216
dataset_format,
200217
dataset_fname,
201218
dataset_dir,
219+
split_dir,
202220
train_dataset_dir,
203221
val_dataset_dir,
204222
test_dataset_dir,
@@ -218,6 +236,7 @@ def evaluate(
218236
dataset_format,
219237
dataset_fname,
220238
dataset_dir,
239+
split_dir,
221240
train_dataset_dir,
222241
val_dataset_dir,
223242
test_dataset_dir,

detectionmetrics/datasets/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@
2323
"generic_lidar_segmentation": GenericLiDARSegmentationDataset,
2424
"goose_image_segmentation": GOOSEImageSegmentationDataset,
2525
"goose_lidar_segmentation": GOOSELiDARSegmentationDataset,
26-
"rellis_image_segmentation": Rellis3DImageSegmentationDataset,
27-
"rellis_lidar_segmentation": Rellis3DLiDARSegmentationDataset,
26+
"rellis3d_image_segmentation": Rellis3DImageSegmentationDataset,
27+
"rellis3d_lidar_segmentation": Rellis3DLiDARSegmentationDataset,
2828
}

detectionmetrics/models/__init__.py

+22-14
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1-
from detectionmetrics.models.torch import (
2-
TorchImageSegmentationModel,
3-
TorchLiDARSegmentationModel,
4-
)
5-
from detectionmetrics.models.tensorflow import (
6-
TensorflowImageSegmentationModel,
7-
)
8-
9-
10-
REGISTRY = {
11-
"torch_image_segmentation": TorchImageSegmentationModel,
12-
"torch_lidar_segmentation": TorchLiDARSegmentationModel,
13-
"tensorflow_image_segmentation": TensorflowImageSegmentationModel,
14-
}
1+
REGISTRY = {}
2+
3+
try:
4+
from detectionmetrics.models.torch import (
5+
TorchImageSegmentationModel,
6+
TorchLiDARSegmentationModel,
7+
)
8+
9+
REGISTRY["torch_image_segmentation"] = TorchImageSegmentationModel
10+
REGISTRY["torch_lidar_segmentation"] = TorchLiDARSegmentationModel
11+
except ImportError:
12+
print("Torch not available")
13+
14+
try:
15+
from detectionmetrics.models.tensorflow import TensorflowImageSegmentationModel
16+
17+
REGISTRY["tensorflow_image_segmentation"] = TensorflowImageSegmentationModel
18+
except ImportError:
19+
print("Tensorflow not available")
20+
21+
if not REGISTRY:
22+
raise Exception("No valid deep learning framework found")

detectionmetrics/models/tensorflow.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def inference(self, image: Image.Image) -> Image.Image:
157157
# TODO: check if this is consistent across different models
158158
result = self.model.signatures["serving_default"](tensor)
159159
if isinstance(result, dict):
160-
result = result["output_0"]
160+
result = list(result.values())[0]
161161

162162
return self.t_out(result)
163163

@@ -195,18 +195,19 @@ def eval(
195195
# Init metrics
196196
results = {}
197197
iou = um.IoU(self.n_classes)
198-
acc = um.Accuracy(self.n_classes)
198+
cm = um.ConfusionMatrix(self.n_classes)
199199

200200
# Evaluation loop
201201
pbar = tqdm(dataset.dataset)
202202
for image, label in pbar:
203+
# TODO: check if this is consistent across different models
203204
pred = self.model.signatures["serving_default"](image)
204205
if isinstance(pred, dict):
205-
pred = pred["output_0"]
206+
pred = list(pred.values())[0]
206207

207208
label = tf.squeeze(label, axis=3)
208209
pred = tf.argmax(pred, axis=3)
209-
acc.update(pred.numpy(), label.numpy())
210+
cm.update(pred.numpy(), label.numpy())
210211

211212
pred = tf.one_hot(pred, self.n_classes)
212213
pred = tf.transpose(pred, perm=[0, 3, 1, 2])
@@ -218,7 +219,7 @@ def eval(
218219

219220
# Get metrics results
220221
iou_per_class, iou = iou.compute()
221-
acc_per_class, acc = acc.compute()
222+
acc_per_class, acc = cm.get_accuracy()
222223
iou_per_class = [float(n) for n in iou_per_class]
223224
acc_per_class = [float(n) for n in acc_per_class]
224225

detectionmetrics/models/torch.py

+26-15
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from detectionmetrics.datasets import dataset as dm_dataset
1313
from detectionmetrics.models import model as dm_model
1414
from detectionmetrics.models import torch_model_utils as tmu
15-
import detectionmetrics.utils.conversion as uc
16-
import detectionmetrics.utils.io as uio
1715
import detectionmetrics.utils.lidar as ul
1816
import detectionmetrics.utils.metrics as um
1917

@@ -224,11 +222,24 @@ def __init__(self, model_fname: str, model_cfg: str, ontology_fname: str):
224222
)
225223
]
226224

227-
self.transform_input += [
228-
transforms.ToImage(),
229-
transforms.ToDtype(torch.float32, scale=True),
230-
]
231-
self.transform_label += [transforms.ToImage(), transforms.ToDtype(torch.int64)]
225+
try:
226+
self.transform_input += [
227+
transforms.ToImage(),
228+
transforms.ToDtype(torch.float32, scale=True),
229+
]
230+
self.transform_label += [
231+
transforms.ToImage(),
232+
transforms.ToDtype(torch.int64),
233+
]
234+
except AttributeError: # adapt for older versions of torchvision transforms v2
235+
self.transform_input += [
236+
transforms.ToImageTensor(),
237+
transforms.ConvertDtype(torch.float32),
238+
]
239+
self.transform_label += [
240+
transforms.ToImageTensor(),
241+
transforms.ToDtype(torch.int64),
242+
]
232243

233244
if "normalization" in self.model_cfg:
234245
self.transform_input += [
@@ -311,7 +322,7 @@ def eval(
311322
# Init metrics
312323
results = {}
313324
iou = um.IoU(self.n_classes)
314-
acc = um.Accuracy(self.n_classes)
325+
cm = um.ConfusionMatrix(self.n_classes)
315326

316327
# Evaluation loop
317328
with torch.no_grad():
@@ -335,13 +346,13 @@ def eval(
335346
if lut_ontology is not None:
336347
label = lut_ontology[label]
337348

338-
# Prepare data and update accuracy
349+
# Prepare data and update confusion matrix
339350
label = label.squeeze(dim=1).cpu()
340351
pred = torch.argmax(pred, axis=1).cpu()
341352
if valid_mask is not None:
342353
valid_mask = valid_mask.squeeze(dim=1).cpu()
343354

344-
acc.update(
355+
cm.update(
345356
pred.numpy(),
346357
label.numpy(),
347358
valid_mask.numpy() if valid_mask is not None else None,
@@ -363,7 +374,7 @@ def eval(
363374

364375
# Get metrics results
365376
iou_per_class, iou = iou.compute()
366-
acc_per_class, acc = acc.compute()
377+
acc_per_class, acc = cm.get_accuracy()
367378
iou_per_class = [float(n) for n in iou_per_class]
368379
acc_per_class = [float(n) for n in acc_per_class]
369380

@@ -526,7 +537,7 @@ def eval(
526537

527538
# Init metrics
528539
iou = um.IoU(self.n_classes)
529-
acc = um.Accuracy(self.n_classes)
540+
cm = um.ConfusionMatrix(self.n_classes)
530541

531542
# Evaluation loop
532543
results = {}
@@ -590,13 +601,13 @@ def eval(
590601
if lut_ontology is not None:
591602
label = lut_ontology[label]
592603

593-
# Prepare data and update accuracy
604+
# Prepare data and update confusion matrix
594605
label = label.cpu().unsqueeze(0)
595606
pred = self.transform_output(pred).cpu().unsqueeze(0).to(torch.int64)
596607
if valid_mask is not None:
597608
valid_mask = valid_mask.cpu().unsqueeze(0)
598609

599-
acc.update(
610+
cm.update(
600611
pred.numpy(),
601612
label.numpy(),
602613
valid_mask.numpy() if valid_mask is not None else None,
@@ -618,7 +629,7 @@ def eval(
618629

619630
# Get metrics results
620631
iou_per_class, iou = iou.compute()
621-
acc_per_class, acc = acc.compute()
632+
acc_per_class, acc = cm.get_accuracy()
622633
iou_per_class = [float(n) for n in iou_per_class]
623634
acc_per_class = [float(n) for n in acc_per_class]
624635

0 commit comments

Comments
 (0)