
Commit e6a4752

zina-csl-bat and Liubov Talamanova authored
Update aggregator.py (#2995)
### Changes

Added a warning message.

### Reason for changes

Emit a warning to flag the inconsistency that arises when the dataset size is smaller than the provided or default `subset_size`.

### Related tickets

Closes: #2562

I had an inquiry: I noticed that `subset_size` is sometimes set to 100 or 300, or specified in the advanced parameters. Should a default be used here, or could you point me to where I can find the correct `subset_size` to import?

---------

Co-authored-by: Liubov Talamanova <piccione-mail@yandex.ru>
1 parent 5619afb commit e6a4752

File tree

1 file changed: +9 -4 lines changed


nncf/common/tensor_statistics/aggregator.py (+9 -4)
@@ -17,6 +17,7 @@
 from nncf.common import factory
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.common.logging import nncf_logger
 from nncf.common.logging.track_progress import track
 from nncf.common.tensor import NNCFTensor
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer

@@ -68,9 +69,8 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
         transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
         model_with_outputs: TModel = model_transformer.transform(transformation_layout)
         engine = factory.EngineFactory.create(model_with_outputs)
-
         iterations_number = self._get_iterations_number()
-        empty_statistics = True
+        processed_samples = 0
         for input_data in track(  # type: ignore
             islice(self.dataset.get_inference_data(), iterations_number),
             total=iterations_number,

@@ -79,9 +79,14 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
             outputs = engine.infer(input_data)
             processed_outputs = self._process_outputs(outputs)
             self._register_statistics(processed_outputs, merged_statistics)
-            empty_statistics = False
-        if empty_statistics:
+            processed_samples += 1
+        if processed_samples == 0:
             raise nncf.ValidationError(EMPTY_DATASET_ERROR)
+        if self.stat_subset_size is not None and self.stat_subset_size > processed_samples:
+            nncf_logger.warning(
+                f"Dataset contains only {processed_samples} samples, "
+                f"smaller than the requested subset size {self.stat_subset_size}."
+            )

     def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
         """
