Skip to content

Commit bb7590b

Browse files
authored
Merge pull request #270 from JdeRobot/269-allow-for-en-arbitrary-choice-of-splits-during-evaluation
Allow for an arbitrary choice of splits during evaluation
2 parents 1503929 + af00bbe commit bb7590b

14 files changed

+124
-53
lines changed

detectionmetrics/cli/evaluate.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ def get_dataset(
4646
raise ValueError("--dataset_ontology is required for 'rellis3d' format")
4747

4848
elif dataset_format in ["goose", "generic"]:
49-
if split == "train" and train_dataset_dir is None:
49+
if "train" in split and train_dataset_dir is None:
5050
raise ValueError(
5151
"--train_dataset_dir is required for 'train' split in 'goose' format"
5252
)
53-
elif split == "val" and val_dataset_dir is None:
53+
elif "val" in split and val_dataset_dir is None:
5454
raise ValueError(
5555
"--val_dataset_dir is required for 'val' split in 'goose' format"
5656
)
57-
elif split == "test" and test_dataset_dir is None:
57+
elif "test" in split and test_dataset_dir is None:
5858
raise ValueError(
5959
"--test_dataset_dir is required for 'test' split in 'goose' format"
6060
)
@@ -118,6 +118,17 @@ def get_dataset(
118118
)
119119
return datasets.REGISTRY[dataset_name](**dataset_args)
120120

121+
def parse_split(ctx, param, value):
    """Parse and validate the ``--split`` option value.

    Used as the click callback for ``--split``; ``evaluate`` also calls it
    directly with ``ctx`` and ``param`` set to ``None`` when it receives a
    plain string.

    :param ctx: Click context, or None when called directly
    :param param: Click parameter being parsed, or None when called directly
    :param value: Split name or comma-separated list of split names
    :type value: str
    :return: Requested split names, stripped of whitespace and lowercased
    :rtype: List[str]
    :raises click.BadParameter: If any requested split is not train, val or test
    """
    valid_splits = ["train", "val", "test"]

    # Tolerate "train, val" and mixed case: the option was previously declared
    # as a case-insensitive click.Choice, so "TEST" must keep working.
    splits = [split.strip().lower() for split in value.split(",")]

    invalid = [split for split in splits if split not in valid_splits]
    if invalid:
        raise click.BadParameter(
            f"Split must be one of {valid_splits} or a comma-separated list of them"
            f" (got {invalid})",
            ctx=ctx,
            param=param,
            # param_hint names the offending parameter, not its value. Click
            # derives it from `param` when available; fall back to the option
            # name for direct calls where no click parameter exists.
            param_hint=None if param is not None else "--split",
        )

    return splits
121132

122133
@click.command(name="evaluate", help="Evaluate model on dataset")
123134
@click.argument("task", type=click.Choice(["segmentation"], case_sensitive=False))
@@ -219,10 +230,10 @@ def get_dataset(
219230
)
220231
@click.option(
221232
"--split",
222-
type=click.Choice(["train", "val", "test"], case_sensitive=False),
223233
show_default=True,
224234
default="test",
225-
help="Name of the split to be evaluated",
235+
callback=parse_split,
236+
help="Name of the split or splits separated by commas to be evaluated",
226237
)
227238
@click.option(
228239
"--ontology_translation",
@@ -260,6 +271,8 @@ def evaluate(
260271
out_fname,
261272
):
262273
"""Evaluate model on dataset"""
274+
if isinstance(split, str): # if evaluate has been called directly
275+
split = parse_split(None, None, split)
263276

264277
model = get_model(task, input_type, model_format, model, model_ontology, model_cfg)
265278
dataset = get_dataset(

detectionmetrics/datasets/dataset.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,28 @@ def append(self, new_dataset: Self):
4444
:param new_dataset: Dataset to be appended
4545
:type new_dataset: Self
4646
"""
47-
assert self.ontology == new_dataset.ontology, "Ontologies don't match"
47+
if not self.has_label_count:
48+
assert self.ontology == new_dataset.ontology, "Ontologies don't match"
49+
else:
50+
# Check if classes match
51+
assert (
52+
self.ontology.keys() == new_dataset.ontology.keys()
53+
), "Ontologies don't match"
54+
for class_name in self.ontology:
55+
# Check if indices, and RGB values match
56+
assert (
57+
self.ontology[class_name]["idx"]
58+
== new_dataset.ontology[class_name]["idx"]
59+
), "Ontologies don't match"
60+
assert (
61+
self.ontology[class_name]["rgb"]
62+
== new_dataset.ontology[class_name]["rgb"]
63+
), "Ontologies don't match"
64+
65+
# Accumulate label count
66+
self.ontology[class_name]["label_count"] += new_dataset.ontology[
67+
class_name
68+
]["label_count"]
4869

4970
# Global filenames to avoid dealing with each dataset relative location
5071
self.make_fname_global()

detectionmetrics/models/model.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
import os
3-
from typing import Any, Optional, Union
3+
from typing import Any, List, Optional, Union
44

55
import numpy as np
66
import pandas as pd
@@ -67,15 +67,15 @@ def inference(
6767
def eval(
6868
self,
6969
dataset: dm_dataset.SegmentationDataset,
70-
split: str = "all",
70+
split: str | List[str] = "test",
7171
ontology_translation: Optional[str] = None,
7272
) -> pd.DataFrame:
7373
"""Perform evaluation for an image segmentation dataset
7474
7575
:param dataset: Segmentation dataset for which the evaluation will be performed
7676
:type dataset: ImageSegmentationDataset
77-
:param split: Split to be used from the dataset, defaults to "all"
78-
:type split: str, optional
77+
:param split: Split or splits to be used from the dataset, defaults to "test"
78+
:type split: str | List[str], optional
7979
:param ontology_translation: JSON file containing translation between dataset and model output ontologies
8080
:type ontology_translation: str, optional
8181
:return: DataFrame containing evaluation results
@@ -158,15 +158,15 @@ def inference(self, image: Image.Image) -> Image.Image:
158158
def eval(
159159
self,
160160
dataset: dm_dataset.ImageSegmentationDataset,
161-
split: str = "all",
161+
split: str | List[str] = "test",
162162
ontology_translation: Optional[str] = None,
163163
) -> pd.DataFrame:
164164
"""Perform evaluation for an image segmentation dataset
165165
166166
:param dataset: Image segmentation dataset for which the evaluation will be performed
167167
:type dataset: ImageSegmentationDataset
168-
:param split: Split to be used from the dataset, defaults to "all"
169-
:type split: str, optional
168+
:param split: Split or splits to be used from the dataset, defaults to "test"
169+
:type split: str | List[str], optional
170170
:param ontology_translation: JSON file containing translation between dataset and model output ontologies
171171
:type ontology_translation: str, optional
172172
:return: DataFrame containing evaluation results
@@ -215,15 +215,15 @@ def inference(self, points: np.ndarray) -> np.ndarray:
215215
def eval(
216216
self,
217217
dataset: dm_dataset.LiDARSegmentationDataset,
218-
split: str = "all",
218+
split: str | List[str] = "test",
219219
ontology_translation: Optional[str] = None,
220220
) -> pd.DataFrame:
221221
"""Perform evaluation for a LiDAR segmentation dataset
222222
223223
:param dataset: LiDAR segmentation dataset for which the evaluation will be performed
224224
:type dataset: LiDARSegmentationDataset
225-
:param split: Split to be used from the dataset, defaults to "all"
226-
:type split: str, optional
225+
:param split: Split or splits to be used from the dataset, defaults to "test"
226+
:type split: str | List[str], optional
227227
:param ontology_translation: JSON file containing translation between dataset and model output ontologies
228228
:type ontology_translation: str, optional
229229
:return: DataFrame containing evaluation results

detectionmetrics/models/torch.py

+17-19
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections import defaultdict
22
import os
33
import time
4-
from typing import Any, Optional, Tuple, Union
4+
from typing import Any, List, Optional, Tuple, Union
55

66
import numpy as np
77
import pandas as pd
@@ -139,20 +139,19 @@ class ImageSegmentationTorchDataset(Dataset):
139139
:type transform: transforms.Compose
140140
:param target_transform: Transformation to be applied to labels
141141
:type target_transform: transforms.Compose
142-
:param split: Split to be used from the dataset, defaults to "all"
143-
:type split: str, optional
142+
:param splits: Splits to be used from the dataset, defaults to ["test"]
143+
:type splits: List[str], optional
144144
"""
145145

146146
def __init__(
147147
self,
148148
dataset: dm_dataset.ImageSegmentationDataset,
149149
transform: transforms.Compose,
150150
target_transform: transforms.Compose,
151-
split: str = "all",
151+
splits: List[str] = ["test"],
152152
):
153153
# Filter split and make filenames global
154-
if split != "all":
155-
dataset.dataset = dataset.dataset[dataset.dataset["split"] == split]
154+
dataset.dataset = dataset.dataset[dataset.dataset["split"].isin(splits)]
156155
self.dataset = dataset
157156
self.dataset.make_fname_global()
158157

@@ -192,8 +191,8 @@ class LiDARSegmentationTorchDataset(Dataset):
192191
:type preprocess: callable
193192
:param n_classes: Number of classes estimated by the model
194193
:type n_classes: int
195-
:param split: Split to be used from the dataset, defaults to "all"
196-
:type split: str, optional
194+
:param splits: Splits to be used from the dataset, defaults to ["test"]
195+
:type splits: List[str], optional
197196
"""
198197

199198
def __init__(
@@ -202,11 +201,10 @@ def __init__(
202201
model_cfg: dict,
203202
preprocess: callable,
204203
n_classes: int,
205-
split: str = "all",
204+
splits: str = ["test"],
206205
):
207206
# Filter split and make filenames global
208-
if split != "all":
209-
dataset.dataset = dataset.dataset[dataset.dataset["split"] == split]
207+
dataset.dataset = dataset.dataset[dataset.dataset["split"].isin(splits)]
210208
self.dataset = dataset
211209
self.dataset.make_fname_global()
212210

@@ -366,15 +364,15 @@ def inference(self, image: Image.Image) -> Image.Image:
366364
def eval(
367365
self,
368366
dataset: dm_dataset.ImageSegmentationDataset,
369-
split: str = "all",
367+
split: str | List[str] = "test",
370368
ontology_translation: Optional[str] = None,
371369
) -> pd.DataFrame:
372370
"""Perform evaluation for an image segmentation dataset
373371
374372
:param dataset: Image segmentation dataset for which the evaluation will be performed
375373
:type dataset: ImageSegmentationDataset
376-
:param split: Split to be used from the dataset, defaults to "all"
377-
:type split: str, optional
374+
:param split: Split or splits to be used from the dataset, defaults to "test"
375+
:type split: str | List[str], optional
378376
:param ontology_translation: JSON file containing translation between dataset and model output ontologies
379377
:type ontology_translation: str, optional
380378
:return: DataFrame containing evaluation results
@@ -394,7 +392,7 @@ def eval(
394392
dataset,
395393
transform=self.transform_input,
396394
target_transform=self.transform_label,
397-
split=split,
395+
splits=[split] if isinstance(split, str) else split,
398396
)
399397

400398
dataloader = DataLoader(
@@ -605,15 +603,15 @@ def inference(self, points: np.ndarray) -> np.ndarray:
605603
def eval(
606604
self,
607605
dataset: dm_dataset.LiDARSegmentationDataset,
608-
split: str = "all",
606+
split: str | List[str] = "test",
609607
ontology_translation: Optional[str] = None,
610608
) -> pd.DataFrame:
611609
"""Perform evaluation for a LiDAR segmentation dataset
612610
613611
:param dataset: LiDAR segmentation dataset for which the evaluation will be performed
614612
:type dataset: LiDARSegmentationDataset
615-
:param split: Split to be used from the dataset, defaults to "all"
616-
:type split: str, optional
613+
:param split: Split or splits to be used from the dataset, defaults to "test"
614+
:type split: str | List[str], optional
617615
:param ontology_translation: JSON file containing translation between dataset and model output ontologies
618616
:type ontology_translation: str, optional
619617
:return: DataFrame containing evaluation results
@@ -634,7 +632,7 @@ def eval(
634632
model_cfg=self.model_cfg,
635633
preprocess=self.preprocess,
636634
n_classes=self.n_classes,
637-
split=split,
635+
splits=[split] if isinstance(split, str) else split,
638636
)
639637

640638
# Init metrics
Binary file not shown.
Binary file not shown.
1.35 KB
Binary file not shown.

docs/py_docs/_build/html/detectionmetrics.cli.html

+6
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ <h2>Submodules<a class="headerlink" href="#submodules" title="Link to this headi
8989
<span class="sig-prename descclassname"><span class="pre">detectionmetrics.cli.evaluate.</span></span><span class="sig-name descname"><span class="pre">get_model</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">task</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_type</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_format</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ontology</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_cfg</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#detectionmetrics.cli.evaluate.get_model" title="Link to this definition"></a></dt>
9090
<dd></dd></dl>
9191

92+
<dl class="py function">
93+
<dt class="sig sig-object py" id="detectionmetrics.cli.evaluate.parse_split">
94+
<span class="sig-prename descclassname"><span class="pre">detectionmetrics.cli.evaluate.</span></span><span class="sig-name descname"><span class="pre">parse_split</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">ctx</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#detectionmetrics.cli.evaluate.parse_split" title="Link to this definition"></a></dt>
95+
<dd><p>Parse split argument</p>
96+
</dd></dl>
97+
9298
</section>
9399
<section id="module-detectionmetrics.cli">
94100
<span id="module-contents"></span><h2>Module contents<a class="headerlink" href="#module-detectionmetrics.cli" title="Link to this heading"></a></h2>

docs/py_docs/_build/html/detectionmetrics.html

+1
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ <h2>Subpackages<a class="headerlink" href="#subpackages" title="Link to this hea
177177
</ul>
178178
</li>
179179
<li class="toctree-l3"><a class="reference internal" href="detectionmetrics.models.html#detectionmetrics.models.tensorflow.get_computational_cost"><code class="docutils literal notranslate"><span class="pre">get_computational_cost()</span></code></a></li>
180+
<li class="toctree-l3"><a class="reference internal" href="detectionmetrics.models.html#detectionmetrics.models.tensorflow.resize_image"><code class="docutils literal notranslate"><span class="pre">resize_image()</span></code></a></li>
180181
</ul>
181182
</li>
182183
<li class="toctree-l2"><a class="reference internal" href="detectionmetrics.models.html#module-detectionmetrics.models.torch">detectionmetrics.models.torch module</a><ul>

0 commit comments

Comments
 (0)