import sys
sys.path.append('./src')

import click
import torch  # torch.device is used below; make the dependency explicit

import utils
import const
from data.datasets import *
from data.dataloaders import *
from model.losses import *
from pipeline import Pipeline, METRICS_DICT


@click.group()
def cli():
    """Segmentation pipeline with additional utilities."""


@cli.command(short_help='Build and train the model. Heavy augs and warm start are supported.')
@click.option('--launch', help='launch location. used to determine default paths',
              type=click.Choice(['local', 'server']), default='server', show_default=True)
@click.option('--architecture', 'model_architecture', help='model architecture (unet, mnet2)',
              type=click.Choice(['unet', 'mnet2']), default='unet', show_default=True)
@click.option('--device', help='device to use',
              type=click.Choice(['cpu', 'cuda:0', 'cuda:1']), default='cuda:0', show_default=True)
@click.option('--dataset', 'dataset_type', help='dataset type',
              type=click.Choice(['nifti', 'numpy']), default='numpy', show_default=True)
@click.option('--heavy-augs/--no-heavy-augs', 'apply_heavy_augs',
              help='whether to apply a different number of augmentations to hard and regular train images'
                   ' (uses docs/hard_cases_mapping.csv to identify hard cases)',
              default=True, show_default=True)
@click.option('--epochs', 'n_epochs', help='max number of epochs to train',
              type=click.INT, required=True)
@click.option('--out', 'out_dp', help='directory path to store artifacts',
              type=click.STRING, default=None)
@click.option('--max-batches', help='max number of batches to process. use as a sanity check. '
                                    'if no value is passed, the whole dataset is processed.',
              type=click.INT, default=None)
@click.option('--checkpoint', 'initial_checkpoint_fp', help='path to an initial .pth checkpoint for warm start',
              type=click.STRING, default=None)
def train(
        launch: str, model_architecture: str, device: str, dataset_type: str,
        apply_heavy_augs: bool, n_epochs: int, out_dp: str, max_batches: int,
        initial_checkpoint_fp: str
):
"""Build and train the model. Heavy augs and warm start are supported."""
loss_func = METRICS_DICT['NegDiceLoss']
metrics = [
METRICS_DICT['BCELoss'],
METRICS_DICT['NegDiceLoss'],
METRICS_DICT['FocalLoss']
]
const.set_launch_type_env_var(launch == 'local')
data_paths = const.DataPaths()
split = utils.load_split_from_yaml(const.TRAIN_VALID_SPLIT_FP)
if dataset_type == 'nifti':
train_dataset = NiftiDataset(data_paths.scans_dp, data_paths.masks_dp, split['train'])
valid_dataset = NiftiDataset(data_paths.scans_dp, data_paths.masks_dp, split['valid'])
elif dataset_type == 'numpy':
ndp = const.NumpyDataPaths(data_paths.default_numpy_dataset_dp)
train_dataset = NumpyDataset(ndp.scans_dp, ndp.masks_dp, ndp.shapes_fp, split['train'])
valid_dataset = NumpyDataset(ndp.scans_dp, ndp.masks_dp, ndp.shapes_fp, split['valid'])
else:
raise ValueError(f"`dataset` should be in ['nifti', 'numpy']. passed '{dataset_type}'")
# init train data loader
if apply_heavy_augs:
print('\nwill apply heavy augmentations for train images')
# set different augmentations for hard and general cases
ids_hard_train = utils.get_image_ids_with_hard_cases_in_train_set(
const.HARD_CASES_MAPPING, const.TRAIN_VALID_SPLIT_FP
)
train_dataset.set_different_aug_cnt_for_two_subsets(1, ids_hard_train, 3)
# init loader
train_loader = DataLoaderNoAugmentations(train_dataset, batch_size=4, to_shuffle=True)
else:
print('\nwill apply the same augmentations for all train images')
train_loader = DataLoaderWithAugmentations(
train_dataset, orig_img_per_batch=2, aug_cnt=1, to_shuffle=True
)
valid_loader = DataLoaderNoAugmentations(valid_dataset, batch_size=4, to_shuffle=False)
device_t = torch.device(device)
pipeline = Pipeline(model_architecture=model_architecture, device=device_t)
pipeline.train(
train_loader=train_loader, valid_loader=valid_loader,
n_epochs=n_epochs, loss_func=loss_func, metrics=metrics,
out_dp=out_dp, max_batches=max_batches, initial_checkpoint_fp=initial_checkpoint_fp
)
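
# Example invocation (illustrative only; the output directory and checkpoint
# path are placeholders, not files shipped with the repo):
#   python main.py train --launch local --dataset numpy --epochs 30 --out results/run_01
#   python main.py train --epochs 10 --checkpoint results/run_01/last.pth  # warm start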


@cli.command(short_help='Segment scans with an already trained model.')
@click.option('--launch', help='launch location. used to determine default paths',
              type=click.Choice(['local', 'server']), default='server', show_default=True)
@click.option('--architecture', 'model_architecture', help='model architecture (unet, mnet2)',
              type=click.Choice(['unet', 'mnet2']), default='unet', show_default=True)
@click.option('--device', help='device to use',
              type=click.Choice(['cpu', 'cuda:0', 'cuda:1']), default='cuda:0', show_default=True)
@click.option('--checkpoint', 'checkpoint_fp',
              help='path to checkpoint .pth file',
              type=click.STRING, default=None)
@click.option('--scans', 'scans_dp', help='path to directory with nifti scans',
              type=click.STRING, default=None)
@click.option('--subset', help='which scans to segment under the --scans dir: '
                               'either all of them or only the ones from the "validation" split',
              type=click.Choice(['all', 'validation']), default='all', show_default=True)
@click.option('--out', 'output_dp', help='path to output directory for segmented masks',
              type=click.STRING, default=None)
@click.option('--postfix', help='postfix to set for segmented masks',
              type=click.STRING, default='autolungs', show_default=True)
def segment_scans(
        launch: str, model_architecture: str, device: str,
        checkpoint_fp: str, scans_dp: str, subset: str,
        output_dp: str, postfix: str
):
    """Segment Nifti `.nii.gz` scans with an already trained model stored in a `.pth` file."""
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    device_t = torch.device(device)
    pipeline = Pipeline(model_architecture=model_architecture, device=device_t)

    scans_dp = scans_dp or data_paths.scans_dp

    ids_list = None
    if subset == 'validation':
        split = utils.load_split_from_yaml(const.TRAIN_VALID_SPLIT_FP)
        ids_list = split['valid']

    pipeline.segment_scans(
        checkpoint_fp=checkpoint_fp, scans_dp=scans_dp,
        ids=ids_list, output_dp=output_dp, postfix=postfix
    )
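
# Example invocation (illustrative; the checkpoint path is a placeholder, and
# the dashed command name assumes click >= 7, which maps underscores to dashes):
#   python main.py segment-scans --launch local --checkpoint results/run_01/last.pth \
#       --subset validation --out results/run_01/masks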


@cli.command(short_help='Find optimal LR for training with the 1-cycle policy.')
@click.option('--launch', help='launch location. used to determine default paths',
              type=click.Choice(['local', 'server']), default='server', show_default=True)
@click.option('--architecture', 'model_architecture', help='model architecture (unet, mnet2)',
              type=click.Choice(['unet', 'mnet2']), default='unet', show_default=True)
@click.option('--device', help='device to use',
              type=click.Choice(['cpu', 'cuda:0', 'cuda:1']), default='cuda:0', show_default=True)
@click.option('--dataset', 'dataset_type', help='dataset type',
              type=click.Choice(['nifti', 'numpy']), default='numpy', show_default=True)
@click.option('--out', 'out_dp', help='directory path to store artifacts',
              type=click.STRING, default=None)
def lr_find(
        launch: str, model_architecture: str, device: str,
        dataset_type: str, out_dp: str
):
    """Find optimal LR for training with the 1-cycle policy."""
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()
    split = utils.load_split_from_yaml(const.TRAIN_VALID_SPLIT_FP)

    if dataset_type == 'nifti':
        train_dataset = NiftiDataset(data_paths.scans_dp, data_paths.masks_dp, split['train'])
    elif dataset_type == 'numpy':
        ndp = const.NumpyDataPaths(data_paths.default_numpy_dataset_dp)
        train_dataset = NumpyDataset(ndp.scans_dp, ndp.masks_dp, ndp.shapes_fp, split['train'])
    else:
        raise ValueError(f"`dataset` should be in ['nifti', 'numpy']. passed '{dataset_type}'")

    loss_func = METRICS_DICT['NegDiceLoss']
    train_loader = DataLoaderNoAugmentations(train_dataset, batch_size=4, to_shuffle=True)

    device_t = torch.device(device)
    pipeline = Pipeline(model_architecture=model_architecture, device=device_t)
    pipeline.lr_find_and_store(loss_func=loss_func, train_loader=train_loader, out_dp=out_dp)
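
# Example invocation (illustrative; dashed name per click's default naming):
#   python main.py lr-find --launch local --dataset numpy --out results/lr_find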


@cli.command(short_help='Create numpy dataset from initial Nifti scans.')
@click.option('--launch', help='launch location. used to determine default paths',
              type=click.Choice(['local', 'server']), default='server', show_default=True)
@click.option('--scans', 'scans_dp', help='path to directory with nifti scans',
              type=click.STRING, default=None)
@click.option('--masks', 'masks_dp', help='path to directory with nifti binary masks',
              type=click.STRING, default=None)
@click.option('--zoom', 'zoom_factor', help='zoom factor for output images',
              type=click.FLOAT, default=0.25, show_default=True)
@click.option('--out', 'output_dp', help='path to output directory with numpy dataset',
              type=click.STRING, default=None)
def create_numpy_dataset(
        launch: str, scans_dp: str, masks_dp: str, zoom_factor: float, output_dp: str
):
    """Create a numpy dataset from the initial Nifti `.nii.gz` scans to speed up training."""
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    scans_dp = scans_dp or data_paths.scans_dp
    masks_dp = masks_dp or data_paths.masks_dp
    numpy_data_root_dp = data_paths.get_numpy_data_root_dp(zoom_factor=zoom_factor)
    output_dp = output_dp or numpy_data_root_dp

    ds = NiftiDataset(scans_dp, masks_dp)
    ds.store_as_numpy_dataset(output_dp, zoom_factor)
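
# Example invocation (illustrative; dashed name per click's default naming;
# with no --scans/--masks/--out the defaults from const.DataPaths are used):
#   python main.py create-numpy-dataset --launch local --zoom 0.25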


if __name__ == '__main__':
    cli()