 from pathlib import Path
 from typing import Any, Dict, Tuple

-import numpy as np
 import onnx
 import onnxruntime
 import torch
-from tqdm import tqdm
+from rich.progress import track
 from ultralytics.cfg import get_cfg
 from ultralytics.data.converter import coco80_to_coco91_class
 from ultralytics.data.utils import check_det_dataset
-from ultralytics.engine.validator import BaseValidator as Validator
 from ultralytics.models.yolo import YOLO
+from ultralytics.models.yolo.segment.val import SegmentationValidator
 from ultralytics.utils import DATASETS_DIR
 from ultralytics.utils import DEFAULT_CFG
 from ultralytics.utils import ops
 from ultralytics.utils.metrics import ConfusionMatrix

 import nncf

+MODEL_NAME = "yolov8n-seg"
+
 ROOT = Path(__file__).parent.resolve()


 def validate(
-    model: onnx.ModelProto, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
+    model: onnx.ModelProto,
+    data_loader: torch.utils.data.DataLoader,
+    validator: SegmentationValidator,
+    num_samples: int = None,
 ) -> Tuple[Dict, int, int]:
     validator.seen = 0
     validator.jdict = []
@@ -49,7 +53,7 @@ def validate(
     output_names = [output.name for output in session.get_outputs()]
     num_outputs = len(output_names)

-    for batch_i, batch in enumerate(data_loader):
+    for batch_i, batch in enumerate(track(data_loader, description="Validating")):
         if num_samples is not None and batch_i == num_samples:
             break
         batch = validator.preprocess(batch)
@@ -71,7 +75,7 @@ def validate(
     return stats, validator.seen, validator.nt_per_class.sum()


-def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -> None:
+def print_statistics(stats: Dict[str, float], total_images: int, total_objects: int) -> None:
     print("Metrics(Box):")
     mp, mr, map50, mean_ap = (
         stats["metrics/precision(B)"],
@@ -84,38 +88,35 @@ def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -> None:
     pf = "%20s" + "%12i" * 2 + "%12.3g" * 4  # print format
     print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap))

-    # print the mask metrics for segmentation
-    if "metrics/precision(M)" in stats:
-        print("Metrics(Mask):")
-        s_mp, s_mr, s_map50, s_mean_ap = (
-            stats["metrics/precision(M)"],
-            stats["metrics/recall(M)"],
-            stats["metrics/mAP50(M)"],
-            stats["metrics/mAP50-95(M)"],
-        )
-        # Print results
-        s = ("%20s" + "%12s" * 6) % ("Class", "Images", "Labels", "Precision", "Recall", "mAP@.5", "mAP@.5:.95")
-        print(s)
-        pf = "%20s" + "%12i" * 2 + "%12.3g" * 4  # print format
-        print(pf % ("all", total_images, total_objects, s_mp, s_mr, s_map50, s_mean_ap))
-
-
-def prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.data.DataLoader]:
-    validator = model.task_map[model.task]["validator"](args=args)
-    validator.data = check_det_dataset(args.data)
-    validator.stride = 32
+    print("Metrics(Mask):")
+    s_mp, s_mr, s_map50, s_mean_ap = (
+        stats["metrics/precision(M)"],
+        stats["metrics/recall(M)"],
+        stats["metrics/mAP50(M)"],
+        stats["metrics/mAP50-95(M)"],
+    )
+    # Print results
+    s = ("%20s" + "%12s" * 6) % ("Class", "Images", "Labels", "Precision", "Recall", "mAP@.5", "mAP@.5:.95")
+    print(s)
+    pf = "%20s" + "%12i" * 2 + "%12.3g" * 4  # print format
+    print(pf % ("all", total_images, total_objects, s_mp, s_mr, s_map50, s_mean_ap))

-    data_loader = validator.get_dataloader(f"{DATASETS_DIR}/coco128-seg", 1)

+def prepare_validation(model: YOLO, args: Any) -> Tuple[SegmentationValidator, torch.utils.data.DataLoader]:
+    validator: SegmentationValidator = model.task_map[model.task]["validator"](args=args)
+    validator.data = check_det_dataset(args.data)
+    validator.stride = 32
     validator.is_coco = True
     validator.class_map = coco80_to_coco91_class()
     validator.names = model.model.names
     validator.metrics.names = validator.names
     validator.nc = model.model.model[-1].nc
-    validator.nm = 32
     validator.process = ops.process_mask
     validator.plot_masks = []

+    coco_data_path = DATASETS_DIR / "coco128-seg"
+    data_loader = validator.get_dataloader(coco_data_path.as_posix(), 1)
+
     return validator, data_loader

@@ -129,7 +130,7 @@ def prepare_onnx_model(model: YOLO, model_name: str) -> Tuple[onnx.ModelProto, Path]:


 def quantize_ac(
-    model: onnx.ModelProto, data_loader: torch.utils.data.DataLoader, validator_ac: Validator
+    model: onnx.ModelProto, data_loader: torch.utils.data.DataLoader, validator_ac: SegmentationValidator
 ) -> onnx.ModelProto:
     input_name = model.graph.input[0].name

@@ -140,7 +141,7 @@ def transform_fn(data_item: Dict):
     def validation_ac(
         val_model: onnx.ModelProto,
         validation_loader: torch.utils.data.DataLoader,
-        validator: Validator,
+        validator: SegmentationValidator,
         num_samples: int = None,
     ) -> float:
         validator.seen = 0
@@ -155,7 +156,6 @@ def validation_ac(
         output_names = [output.name for output in session.get_outputs()]
         num_outputs = len(output_names)

-        counter = 0
         for batch_i, batch in enumerate(validation_loader):
             if num_samples is not None and batch_i == num_samples:
                 break
@@ -172,13 +172,12 @@ def validation_ac(
             ]
             preds = validator.postprocess(preds)
             validator.update_metrics(preds, batch)
-            counter += 1
+
         stats = validator.get_stats()
         if num_outputs == 1:
             stats_metrics = stats["metrics/mAP50-95(B)"]
         else:
             stats_metrics = stats["metrics/mAP50-95(M)"]
-        print(f"Validate: dataset length = {counter}, metric value = {stats_metrics:.3f}")
         return stats_metrics, None

     quantization_dataset = nncf.Dataset(data_loader, transform_fn)
@@ -213,8 +212,6 @@ def validation_ac(


 def run_example():
-    MODEL_NAME = "yolov8n-seg"
-
     model = YOLO(ROOT / f"{MODEL_NAME}.pt")
     args = get_cfg(cfg=DEFAULT_CFG)
     args.data = "coco128-seg.yaml"
@@ -231,11 +228,11 @@ def run_example():
     print(f"[2/5] Save INT8 model: {int8_model_path}")

     print("[3/5] Validate ONNX FP32 model:")
-    fp_stats, total_images, total_objects = validate(fp32_model, tqdm(data_loader), validator)
+    fp_stats, total_images, total_objects = validate(fp32_model, data_loader, validator)
     print_statistics(fp_stats, total_images, total_objects)

     print("[4/5] Validate ONNX INT8 model:")
-    q_stats, total_images, total_objects = validate(int8_model, tqdm(data_loader), validator)
+    q_stats, total_images, total_objects = validate(int8_model, data_loader, validator)
     print_statistics(q_stats, total_images, total_objects)

     print("[5/5] Report:")
0 commit comments