Skip to content

Commit c4a754c

Browse files
committed
Merge branch 'version_update_for_release' of https://github.com/sarthakpati/GaNDLF into imagenet_acs
2 parents fc0ddab + 38141df commit c4a754c

17 files changed

+183
-24
lines changed

GANDLF/compute/forward_pass.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
get_unique_timestamp,
1616
resample_image,
1717
reverse_one_hot,
18+
get_ground_truths_and_predictions_tensor,
1819
)
20+
from GANDLF.metrics import overall_stats
1921
from tqdm import tqdm
2022

2123

@@ -99,6 +101,13 @@ def validate_network(
99101
"output_predictions_" + get_unique_timestamp() + ".csv",
100102
)
101103

104+
# get ground truths for classification problem, validation set
105+
if is_classification and mode == "validation":
106+
(
107+
ground_truth_array,
108+
predictions_array,
109+
) = get_ground_truths_and_predictions_tensor(params, "validation_data")
110+
102111
for batch_idx, (subject) in enumerate(
103112
tqdm(valid_dataloader, desc="Looping over " + mode + " data")
104113
):
@@ -192,6 +201,11 @@ def validate_network(
192201
final_loss, final_metric = get_loss_and_metrics(
193202
image, valuesToPredict, pred_output, params
194203
)
204+
205+
if is_classification and mode == "validation":
206+
predictions_array[batch_idx] = (
207+
torch.argmax(pred_output[0], 0).cpu().item()
208+
)
195209
# # Non network validation related
196210
total_epoch_valid_loss += final_loss.detach().cpu().item()
197211
for metric in final_metric.keys():
@@ -283,7 +297,7 @@ def validate_network(
283297
attention_map, patches_batch[torchio.LOCATION]
284298
)
285299
else:
286-
_, _, output = result
300+
_, _, output, _ = result
287301

288302
if params["problem_type"] == "segmentation":
289303
aggregator.add_batch(
@@ -359,6 +373,10 @@ def validate_network(
359373
else:
360374
# final regression output
361375
output_prediction = output_prediction / len(patch_loader)
376+
if is_classification and mode == "validation":
377+
predictions_array[batch_idx] = (
378+
torch.argmax(output_prediction[0], 0).cpu().item()
379+
)
362380
if params["save_output"]:
363381
outputToWrite += (
364382
str(epoch)
@@ -453,6 +471,11 @@ def validate_network(
453471
if label_ground_truth is not None:
454472
average_epoch_valid_loss = total_epoch_valid_loss / len(valid_dataloader)
455473
print(" Epoch Final " + mode + " loss : ", average_epoch_valid_loss)
474+
# get overall stats for classification
475+
if is_classification and mode == "validation":
476+
average_epoch_valid_metric = overall_stats(
477+
predictions_array, ground_truth_array, params
478+
)
456479
for metric in params["metrics"]:
457480
if isinstance(total_epoch_valid_metric[metric], np.ndarray):
458481
to_print = (
@@ -461,6 +484,7 @@ def validate_network(
461484
else:
462485
to_print = total_epoch_valid_metric[metric] / len(valid_dataloader)
463486
average_epoch_valid_metric[metric] = to_print
487+
for metric in average_epoch_valid_metric.keys():
464488
print(
465489
" Epoch Final " + mode + " " + metric + " : ",
466490
average_epoch_valid_metric[metric],

GANDLF/compute/step.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def step(model, image, label, params, train=True):
8080
else:
8181
output = model(image)
8282

83+
attention_map = None
8384
if "medcam_enabled" in params and params["medcam_enabled"]:
8485
output, attention_map = output
8586

@@ -97,7 +98,4 @@ def step(model, image, label, params, train=True):
9798
if "medcam_enabled" in params and params["medcam_enabled"]:
9899
attention_map = torch.unsqueeze(attention_map, -1)
99100

100-
if not ("medcam_enabled" in params and params["medcam_enabled"]):
101-
return loss, metric_output, output
102-
else:
103-
return loss, metric_output, output, attention_map
101+
return loss, metric_output, output, attention_map

GANDLF/compute/training_loop.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
version_check,
1717
write_training_patches,
1818
print_model_summary,
19+
get_ground_truths_and_predictions_tensor,
1920
)
21+
from GANDLF.metrics import overall_stats
2022
from GANDLF.logger import Logger
2123
from .step import step
2224
from .forward_pass import validate_network
@@ -69,6 +71,12 @@ def train_network(model, train_dataloader, optimizer, params):
6971
if params["verbose"]:
7072
print("Using Automatic mixed precision", flush=True)
7173

74+
# get ground truths
75+
if params["problem_type"] == "classification":
76+
(
77+
ground_truth_array,
78+
predictions_array,
79+
) = get_ground_truths_and_predictions_tensor(params, "training_data")
7280
# Set the model to train
7381
model.train()
7482
for batch_idx, (subject) in enumerate(
@@ -104,7 +112,15 @@ def train_network(model, train_dataloader, optimizer, params):
104112
params["subject_spacing"] = subject["spacing"]
105113
else:
106114
params["subject_spacing"] = None
107-
loss, calculated_metrics, _ = step(model, image, label, params)
115+
loss, calculated_metrics, output, _ = step(model, image, label, params)
116+
# store predictions for classification
117+
if params["problem_type"] == "classification":
118+
predictions_array[
119+
batch_idx
120+
* params["batch_size"] : (batch_idx + 1)
121+
* params["batch_size"]
122+
] = (torch.argmax(output[0], 0).cpu().item())
123+
108124
nan_loss = torch.isnan(loss)
109125
second_order = (
110126
hasattr(optimizer, "is_second_order") and optimizer.is_second_order
@@ -175,6 +191,12 @@ def train_network(model, train_dataloader, optimizer, params):
175191

176192
average_epoch_train_loss = total_epoch_train_loss / len(train_dataloader)
177193
print(" Epoch Final train loss : ", average_epoch_train_loss)
194+
195+
# get overall stats for classification
196+
if params["problem_type"] == "classification":
197+
average_epoch_train_metric = overall_stats(
198+
predictions_array, ground_truth_array, params
199+
)
178200
for metric in params["metrics"]:
179201
if isinstance(total_epoch_train_metric[metric], np.ndarray):
180202
to_print = (
@@ -183,6 +205,7 @@ def train_network(model, train_dataloader, optimizer, params):
183205
else:
184206
to_print = total_epoch_train_metric[metric] / len(train_dataloader)
185207
average_epoch_train_metric[metric] = to_print
208+
for metric in average_epoch_train_metric.keys():
186209
print(
187210
" Epoch Final train " + metric + " : ",
188211
average_epoch_train_metric[metric],

GANDLF/metrics/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from .regression import classification_accuracy, balanced_acc_score, per_label_accuracy
1414
from .generic import recall_score, precision_score, iou_score, f1_score, accuracy
15+
from .classification import overall_stats
1516

1617

1718
# global defines for the metrics
@@ -35,5 +36,5 @@
3536
"recall": recall_score,
3637
"iou": iou_score,
3738
"balanced_accuracy": balanced_acc_score,
38-
"per_label_accuracy": per_label_accuracy,
39+
"per_label_one_hot_accuracy": per_label_accuracy,
3940
}

GANDLF/metrics/classification.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import torchmetrics as tm
2+
3+
4+
def get_output_from_calculator(predictions, ground_truth, calculator):
    """
    Helper function to get the output from a calculator as plain Python values.

    Args:
        predictions (torch.Tensor): The output of the model.
        ground_truth (torch.Tensor): The ground truth labels.
        calculator (torchmetrics.Metric): The calculator to use; any callable
            taking ``(predictions, ground_truth)`` works.

    Returns:
        float or list: A Python scalar when the calculator yields a 0-dim
        tensor, a (nested) list when it yields a tensor with one or more
        dimensions. A tuple/list result is converted element by element into a
        list following the same rule; any other value is returned unchanged.
    """

    def _to_python(value):
        # Calculators that return several tensors at once hand them back in a
        # tuple/list, so each element is converted on its own.
        if isinstance(value, (list, tuple)):
            return [_to_python(item) for item in value]
        # Duck-typed tensor check: anything without ``dim`` is passed through
        # untouched instead of failing with an AttributeError.
        if not hasattr(value, "dim"):
            return value
        if value.dim() > 0:
            return value.cpu().tolist()
        return value.cpu().item()

    return _to_python(calculator(predictions, ground_truth))
22+
23+
24+
def overall_stats(predictions, ground_truth, params):
    """
    Build a dictionary of metrics computed over all predictions and ground truths at once.

    Args:
        predictions (torch.Tensor): The output of the model.
        ground_truth (torch.Tensor): The ground truth labels.
        params (dict): The parameter dictionary containing training and data information.

    Returns:
        dict: Accuracy, precision, recall and f1 for every averaging scheme,
        keyed as "<metric>_<averaging scheme>", plus a single "auc" entry.
    """
    assert (
        params["problem_type"] == "classification"
    ), "Only classification is supported for overall stats"
    assert len(predictions) == len(
        ground_truth
    ), "Predictions and ground truth must be of same length"

    # output key suffix -> value passed as the torchmetrics "average" argument
    averaging_schemes = (
        ("global", "micro"),
        ("per_class", "none"),
        ("per_class_average", "macro"),
        ("per_class_weighted", "weighted"),
    )

    results = {}
    for suffix, tm_average in averaging_schemes:
        # Metrics that need the "average" parameter. AUROC is left out on
        # purpose: it hit an error on multi-class problems where pos_label
        # was not getting set.
        scorers = {
            "accuracy": tm.Accuracy(
                num_classes=params["model"]["num_classes"], average=tm_average
            ),
            "precision": tm.Precision(
                num_classes=params["model"]["num_classes"], average=tm_average
            ),
            "recall": tm.Recall(
                num_classes=params["model"]["num_classes"], average=tm_average
            ),
            "f1": tm.F1(
                num_classes=params["model"]["num_classes"], average=tm_average
            ),
        }
        for name in scorers:
            results[name + "_" + suffix] = get_output_from_calculator(
                predictions, ground_truth, scorers[name]
            )

    # Metrics that do not have any "average" parameter. ROC is left out for
    # the same pos_label error noted for AUROC.
    results["auc"] = get_output_from_calculator(
        predictions, ground_truth, tm.AUC(reorder=True)
    )

    return results

GANDLF/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
get_class_imbalance_weights_classification,
2020
get_linear_interpolation_mode,
2121
print_model_summary,
22+
get_ground_truths_and_predictions_tensor,
2223
)
2324

2425
from .write_parse import (

GANDLF/utils/tensor.py

+23
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,26 @@ def print_model_summary(
419419
)
420420
temp_output = stats.to_readable(stats.total_mult_adds)
421421
print("\tTotal # of operations:", temp_output[1], temp_output[0])
422+
423+
424+
def get_ground_truths_and_predictions_tensor(params, loader_type):
    """
    Collect the ground truth labels for a loader type and allocate a matching predictions tensor.

    Args:
        params (dict): The parameters passed by the user yaml.
        loader_type (str): The key in ``params`` holding the data table for the loader of interest.

    Returns:
        torch.Tensor, torch.Tensor: The flattened ground truths cast to ``torch.int``, and a zero-filled tensor of the same shape and dtype to hold the predictions.
    """
    # presumably a pandas DataFrame (uses ``.columns`` / ``.to_numpy()``);
    # verify against caller
    data_table = params[loader_type]
    label_columns = data_table.columns[params["headers"]["predictionHeaders"]]

    # flatten the selected label column(s) into a single 1-D array
    flat_labels = data_table[label_columns].to_numpy().ravel()

    truths = torch.from_numpy(flat_labels).type(torch.int)
    placeholders = torch.zeros_like(truths)

    return truths, placeholders

GANDLF/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/usr/bin/env python3
22
# -*- coding: UTF-8 -*-
3-
__version__ = "0.0.15-dev"
3+
__version__ = "0.0.16-dev"

HISTORY.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.0.16-dev
2+
- ImageNet pre-trained models for UNet with variable encoders are now available
3+
- ACS/Soft conversion is available for ImageNet-pretrained UNet
4+
15
## 0.0.15
26
- Updated `setup.py` for `python>=3.8`
37
- `stride_size` is now handled internally for histology data
@@ -9,8 +13,7 @@
913
- Per class accuracy has been added as a metric
1014
- Dedicated rescaling preprocessing function added for increased flexibility
1115
- Largest Connected Component Analysis is now added
12-
- ImageNet pre-trained models for UNet with variable encoders is now available
13-
- ACS/Soft conversion is available for ImageNet-pretrained UNet
16+
- Included metrics using overall predictions and ground truths
1417

1518
## 0.0.14
1619

samples/config_all_options.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# affix version
22
version:
33
{
4-
minimum: 0.0.15,
5-
maximum: 0.0.15 # this should NOT be made a variable, but should be tested after every tag is created
4+
minimum: 0.0.16,
5+
maximum: 0.0.16 # this should NOT be made a variable, but should be tested after every tag is created
66
}
77
## Choose the model parameters here
88
model:

samples/config_classification.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# affix version
22
version:
33
{
4-
minimum: 0.0.15,
5-
maximum: 0.0.15 # this should NOT be made a variable, but should be tested after every tag is created
4+
minimum: 0.0.16,
5+
maximum: 0.0.16 # this should NOT be made a variable, but should be tested after every tag is created
66
}
77
# Choose the model parameters here
88
model:

samples/config_regression.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# affix version
22
version:
33
{
4-
minimum: 0.0.15,
5-
maximum: 0.0.15 # this should NOT be made a variable, but should be tested after every tag is created
4+
minimum: 0.0.16,
5+
maximum: 0.0.16 # this should NOT be made a variable, but should be tested after every tag is created
66
}
77
# Choose the model parameters here
88
model:

samples/config_segmentation_brats.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# affix version
22
version:
33
{
4-
minimum: 0.0.15,
5-
maximum: 0.0.15 # this should NOT be made a variable, but should be tested after every tag is created
4+
minimum: 0.0.16,
5+
maximum: 0.0.16 # this should NOT be made a variable, but should be tested after every tag is created
66
}
77
# Choose the model parameters here
88
model:

samples/config_segmentation_histology.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# affix version
22
version:
33
{
4-
minimum: 0.0.15,
5-
maximum: 0.0.15 # this should NOT be made a variable, but should be tested after every tag is created
4+
minimum: 0.0.16,
5+
maximum: 0.0.16 # this should NOT be made a variable, but should be tested after every tag is created
66
}
77
# Choose the model parameters here
88
model:

testing/config_classification.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ metrics:
1414
}
1515
- accuracy
1616
- balanced_accuracy
17-
- per_label_accuracy
17+
- per_label_one_hot_accuracy
1818
- precision: {
1919
average: weighted,
2020
}
@@ -53,7 +53,7 @@ save_output: false
5353
scaling_factor: 1
5454
scheduler: triangle
5555
version:
56-
maximum: 0.0.15
56+
maximum: 0.0.16
5757
minimum: 0.0.14
5858
weighted_loss: True
5959
which_model: resunet

testing/config_regression.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ save_output: false
3737
scaling_factor: 1
3838
scheduler: triangle
3939
version:
40-
maximum: 0.0.15
40+
maximum: 0.0.16
4141
minimum: 0.0.14
4242
weighted_loss: false
4343
which_model: resunet

0 commit comments

Comments
 (0)