15
15
get_unique_timestamp ,
16
16
resample_image ,
17
17
reverse_one_hot ,
18
+ get_ground_truths_and_predictions_tensor ,
18
19
)
20
+ from GANDLF .metrics import overall_stats
19
21
from tqdm import tqdm
20
22
21
23
@@ -99,6 +101,13 @@ def validate_network(
99
101
"output_predictions_" + get_unique_timestamp () + ".csv" ,
100
102
)
101
103
104
+ # get ground truths for classification problem, validation set
105
+ if is_classification and mode == "validation" :
106
+ (
107
+ ground_truth_array ,
108
+ predictions_array ,
109
+ ) = get_ground_truths_and_predictions_tensor (params , "validation_data" )
110
+
102
111
for batch_idx , (subject ) in enumerate (
103
112
tqdm (valid_dataloader , desc = "Looping over " + mode + " data" )
104
113
):
@@ -192,6 +201,11 @@ def validate_network(
192
201
final_loss , final_metric = get_loss_and_metrics (
193
202
image , valuesToPredict , pred_output , params
194
203
)
204
+
205
+ if is_classification and mode == "validation" :
206
+ predictions_array [batch_idx ] = (
207
+ torch .argmax (pred_output [0 ], 0 ).cpu ().item ()
208
+ )
195
209
# # Non network validation related
196
210
total_epoch_valid_loss += final_loss .detach ().cpu ().item ()
197
211
for metric in final_metric .keys ():
@@ -283,7 +297,7 @@ def validate_network(
283
297
attention_map , patches_batch [torchio .LOCATION ]
284
298
)
285
299
else :
286
- _ , _ , output = result
300
+ _ , _ , output , _ = result
287
301
288
302
if params ["problem_type" ] == "segmentation" :
289
303
aggregator .add_batch (
@@ -359,6 +373,10 @@ def validate_network(
359
373
else :
360
374
# final regression output
361
375
output_prediction = output_prediction / len (patch_loader )
376
+ if is_classification and mode == "validation" :
377
+ predictions_array [batch_idx ] = (
378
+ torch .argmax (output_prediction [0 ], 0 ).cpu ().item ()
379
+ )
362
380
if params ["save_output" ]:
363
381
outputToWrite += (
364
382
str (epoch )
@@ -453,6 +471,11 @@ def validate_network(
453
471
if label_ground_truth is not None :
454
472
average_epoch_valid_loss = total_epoch_valid_loss / len (valid_dataloader )
455
473
print (" Epoch Final " + mode + " loss : " , average_epoch_valid_loss )
474
+ # get overall stats for classification
475
+ if is_classification and mode == "validation" :
476
+ average_epoch_valid_metric = overall_stats (
477
+ predictions_array , ground_truth_array , params
478
+ )
456
479
for metric in params ["metrics" ]:
457
480
if isinstance (total_epoch_valid_metric [metric ], np .ndarray ):
458
481
to_print = (
@@ -461,6 +484,7 @@ def validate_network(
461
484
else :
462
485
to_print = total_epoch_valid_metric [metric ] / len (valid_dataloader )
463
486
average_epoch_valid_metric [metric ] = to_print
487
+ for metric in average_epoch_valid_metric .keys ():
464
488
print (
465
489
" Epoch Final " + mode + " " + metric + " : " ,
466
490
average_epoch_valid_metric [metric ],
0 commit comments