
Commit 6eed8f9

Complete all modules without debug
1 parent dbf595c commit 6eed8f9

32 files changed: +3,126 −174 lines changed

base/__init__.py

+3
@@ -0,0 +1,3 @@
from .base_data_loader import BaseDataLoader
from .base_model import BaseModel
from .base_trainer import BaseTrainer
base/base_data_loader.py

+106
@@ -0,0 +1,106 @@
from copy import copy

import numpy as np


class BaseDataLoader:
    """
    Base class for all data loaders
    """

    def __init__(self, config):
        self.config = config
        self.batch_size = config['data_loader']['batch_size']
        self.shuffle = config['data_loader']['shuffle']
        self.num_workers = config['data_loader']['workers']
        self.activate = config['data_loader']['activate']
        self.val_rate = config['split_rate']['validation']
        self.test_rate = config['split_rate']['test']
        self.batch_idx = 0

    def __iter__(self):
        """
        :return: Iterator
        """
        assert self.__len__() > 0
        self.batch_idx = 0
        if self.shuffle:
            self._shuffle_data()
        return self

    def __next__(self):
        """
        :return: Next batch
        """
        packed = self._pack_data()
        if self.batch_idx < self.__len__():
            batch = packed[self.batch_idx * self.batch_size:(self.batch_idx + 1) * self.batch_size]
            self.batch_idx = self.batch_idx + 1
            return self._unpack_data(batch)
        else:
            raise StopIteration

    def __len__(self):
        """
        :return: Total number of batches
        """
        return self._n_samples() // self.batch_size

    def _n_samples(self):
        """
        :return: Total number of samples
        """
        raise NotImplementedError

    def _pack_data(self):
        """
        Pack all data into a list/tuple/ndarray/...

        :return: Packed data in the data loader
        """
        raise NotImplementedError

    def _unpack_data(self, packed):
        """
        Unpack packed data (from _pack_data())

        :param packed: Packed data
        :return: Unpacked data
        """
        raise NotImplementedError

    def _update_data(self, unpacked):
        """
        Update data members in the data loader

        :param unpacked: Unpacked data (from _unpack_data())
        """
        raise NotImplementedError

    def _shuffle_data(self):
        """
        Shuffle data members in the data loader
        """
        packed = self._pack_data()
        rand_idx = np.random.permutation(len(packed))
        packed = [packed[i] for i in rand_idx]
        self._update_data(self._unpack_data(packed))

    def split_validation(self):
        """
        Split validation data from the data loader based on self.config['validation']
        """
        validation_split = self.config['validation']['validation_split']
        shuffle = self.config['validation']['shuffle']
        if validation_split == 0.0:
            return None
        if shuffle:
            self._shuffle_data()
        valid_data_loader = copy(self)
        split = int(self._n_samples() * validation_split)
        packed = self._pack_data()
        train_data = self._unpack_data(packed[split:])
        val_data = self._unpack_data(packed[:split])
        valid_data_loader._update_data(val_data)
        self._update_data(train_data)
        return valid_data_loader
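
A note on how these hooks compose: __iter__, __next__, _shuffle_data, and split_validation only ever touch the data through _pack_data / _unpack_data / _update_data, so a subclass just has to keep those three consistent (also note that split_validation still reads the old config['validation'] block, while this commit's config.json switches to 'split_rate'). Below is a minimal hypothetical subclass, not part of this commit; the ArrayDataLoader name and the NumPy-array fields are invented for the example.

import numpy as np

from base import BaseDataLoader


class ArrayDataLoader(BaseDataLoader):
    """Hypothetical subclass: wraps two parallel NumPy arrays (inputs, labels)."""

    def __init__(self, config, x, y):
        super(ArrayDataLoader, self).__init__(config)
        self.x = x
        self.y = y

    def _n_samples(self):
        return len(self.x)

    def _pack_data(self):
        # Pack the parallel arrays into one list of (input, label) pairs.
        return list(zip(self.x, self.y))

    def _unpack_data(self, packed):
        # Turn a list of (input, label) pairs back into two arrays.
        xs, ys = zip(*packed)
        return np.array(xs), np.array(ys)

    def _update_data(self, unpacked):
        self.x, self.y = unpacked

Iterating such a loader (for xb, yb in loader: ...) would then yield batch_size-sized array pairs, shuffled whenever config['data_loader']['shuffle'] is true.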

base/base_trainer.py

+185
@@ -0,0 +1,185 @@
import json
import logging
import math
import os

import tensorflow as tf
from tensorboardX import SummaryWriter
from tensorflow import keras

from utils.utils import ensure_dir


class BaseTrainer:
    """
    Base class for all trainers
    """

    def __init__(self, model, loss, metrics, resume, config, train_logger=None):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.name = config['name']
        self.epochs = config['trainer']['epochs']
        self.save_freq = config['trainer']['save_freq']
        self.verbosity = config['trainer']['verbosity']
        self.summaryWriter = SummaryWriter()

        if tf.test.is_gpu_available():
            if config['cuda']:
                self.with_cuda = True
                self.gpus = {i: item for i, item in enumerate(self.config['gpus'])}
                device = 'cuda'
            else:
                self.with_cuda = False
                device = 'cpu'
        else:
            self.logger.warning('Warning: There\'s no CUDA support on this machine, training is performed on CPU.')
            self.with_cuda = False
            device = 'cpu'

        self.device = tf.device(device)
        self.model.to(self.device)

        self.logger.debug('Model is initialized.')
        self._log_memory_usage()

        self.train_logger = train_logger

        self.optimizer = self.model.optimize(config['optimizer_type'], config['optimizer'])

        self.lr_scheduler = getattr(
            keras.callbacks.LearningRateScheduler,
            config['lr_scheduler_type'], None)
        if self.lr_scheduler:
            self.lr_scheduler = self.lr_scheduler(self.optimizer, **config['lr_scheduler'])
            self.lr_scheduler_freq = config['lr_scheduler_freq']
        self.monitor = config['trainer']['monitor']
        self.monitor_mode = config['trainer']['monitor_mode']
        assert self.monitor_mode == 'min' or self.monitor_mode == 'max'
        self.monitor_best = math.inf if self.monitor_mode == 'min' else -math.inf
        self.start_epoch = 1
        self.checkpoint_dir = os.path.join(config['trainer']['save_dir'], self.name)
        ensure_dir(self.checkpoint_dir)
        json.dump(config, open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'),
                  indent=4, sort_keys=False)
        if resume:
            self._resume_checkpoint(resume)

    def train(self):
        """
        Full training logic
        """
        print(self.epochs)
        for epoch in range(self.start_epoch, self.epochs + 1):
            try:
                result = self._train_epoch(epoch)
            except tf.errors.UnavailableError:
                self._log_memory_usage()

            log = {'epoch': epoch}
            for key, value in result.items():
                if key == 'metrics':
                    for i, metric in enumerate(self.metrics):
                        log[metric.__name__] = result['metrics'][i]
                elif key == 'val_metrics':
                    for i, metric in enumerate(self.metrics):
                        log['val_' + metric.__name__] = result['val_metrics'][i]
                else:
                    log[key] = value
            if self.train_logger is not None:
                self.train_logger.add_entry(log)
            if self.verbosity >= 1:
                for key, value in log.items():
                    self.logger.info('    {:15s}: {}'.format(str(key), value))
            if (self.monitor_mode == 'min' and log[self.monitor] < self.monitor_best) \
                    or (self.monitor_mode == 'max' and log[self.monitor] > self.monitor_best):
                self.monitor_best = log[self.monitor]
                self._save_checkpoint(epoch, log, save_best=True)
            if epoch % self.save_freq == 0:
                self._save_checkpoint(epoch, log)
            if self.lr_scheduler:
                self.lr_scheduler.step()
                lr = self.lr_scheduler.get_lr()[0]
                self.logger.info('New Learning Rate: {:.8f}'.format(lr))

            self.summaryWriter.add_scalars('Train', {'train_' + self.monitor: result[self.monitor],
                                                     'val_' + self.monitor: result[self.monitor]}, epoch)
        self.summaryWriter.close()

    # TODO Not Available
    def _log_memory_usage(self):
        if not self.with_cuda:
            return

        template = """Memory Usage: \n{}"""
        usage = []
        for deviceID, device in self.gpus.items():
            deviceID = int(deviceID)
            # allocated = torch.cuda.memory_allocated(deviceID) / (1024 * 1024)
            # cached = torch.cuda.memory_cached(deviceID) / (1024 * 1024)

            # usage.append('    CUDA: {}  Allocated: {} MB  Cached: {} MB \n'.format(device, allocated, cached))

        content = ''.join(usage)
        content = template.format(content)

        self.logger.debug(content)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _save_checkpoint(self, epoch, log, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'logger': self.train_logger,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.monitor_best,
            'config': self.config
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint-epoch{:03d}-loss-{:.4f}.pth.tar'
                                .format(epoch, log['loss']))
        tf.saved_model.save(state, filename)
        if save_best:
            os.rename(filename, os.path.join(self.checkpoint_dir, 'model_best.pth.tar'))
            self.logger.info("Saving current best: {} ...".format('model_best.pth.tar'))
        else:
            self.logger.info("Saving checkpoint: {} ...".format(filename))

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = tf.saved_model.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.monitor_best = checkpoint['monitor_best']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if self.with_cuda:
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, tf.Tensor):
                        state[k] = v.cuda(tf.device('cuda'))
        self.train_logger = checkpoint['logger']
        # self.config = checkpoint['config']
        self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch))

config.json

+12 −5
@@ -4,16 +4,23 @@
     "gpus": [
         0
     ],
+    "training": "True",
     "data_loader": {
-        "dataset": "ICDAR 2019 - LSVT",
-        "data_dir": "F:\\Code\\HealthHelper\\Dataset\\ICDAR 2019 - LSVT",
+        "activate": 0,
+        "datasets": [
+            {
+                "name": "ICDAR 2019 LSVT",
+                "data_dir": "F:\\Code\\HealthHelper\\Dataset\\ICDAR 2019 - LSVT",
+                "have_test": "False"
+            }
+        ],
         "batch_size": 128,
         "shuffle": true,
         "workers": 0
     },
-    "validation": {
-        "validation_split": 0.2,
-        "shuffle": true
+    "split_rate": {
+        "validation": 0.1,
+        "test": 0.1
     },
     "lr_scheduler_type": "",
     "lr_scheduler_freq": 10000,

data_loader/data_loader.py

+51
@@ -0,0 +1,51 @@
import tensorflow as tf
from sklearn import model_selection

from base import BaseDataLoader
from .datasets import ICDAR2019Dataset


class ICDAR2019DataLoaderFactory(BaseDataLoader):

    def __init__(self, config):
        super(ICDAR2019DataLoaderFactory, self).__init__(config)
        dataRoot = config['data_loader']['datasets'][self.activate]['data_dir']
        self.workers = config['data_loader']['workers']
        self.have_test = config['data_loader']['datasets'][self.activate]['have_test']
        dataset = ICDAR2019Dataset(dataRoot)
        self.allDataset = dataset.loadData()

        if self.have_test:
            self.trainDataset, self.testDataset = self.train_val_split(self.allDataset)
            self.trainDataset, self.valDataset = self.train_val_split(self.trainDataset)
        else:
            self.trainDataset, self.valDataset = self.train_val_split(self.allDataset)

    def train(self):
        trainLoader = tf.data.Dataset.from_tensor_slices(self.trainDataset)
        # trainLoader = torchdata.DataLoader(self.trainDataset, num_workers=self.num_workers,
        #                                    batch_size=self.batch_size,
        #                                    shuffle=self.shuffle, collate_fn=collate_fn)
        return trainLoader

    def val(self):
        # valLoader = torchdata.DataLoader(self.valDataset, num_workers=self.num_workers, batch_size=self.batch_size,
        #                                  shuffle=shuffle, collate_fn=collate_fn)
        valLoader = tf.data.Dataset.from_tensor_slices(self.valDataset)
        return valLoader

    def train_val_split(self, dataset):
        """
        :param dataset: dataset
        :return:
        """
        train, val = model_selection.train_test_split(dataset[0], tuple(dataset[1:]), test_size=self.val_rate)
        return train, val

    def train_test_split(self, dataset):
        train, test = model_selection.train_test_split(dataset[0], dataset[1:], test_size=self.test_rate)
        return train, test

    def split_validation(self):
        raise NotImplementedError
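
Two caveats on the split helpers, plus a hypothetical sketch that is not part of this commit. sklearn.model_selection.train_test_split returns two outputs per input array, so passing two arrays yields four values and the two-name unpacking above would fail; and because have_test comes from config.json as the string "False", the if self.have_test: check is always truthy. A sketch that splits a single packed list of sample tuples, keeping the return count at two:

from sklearn import model_selection


def train_val_split(samples, val_rate):
    """Split one list of packed sample tuples, e.g. [(image, boxes, text), ...].

    With a single input sequence, train_test_split returns exactly two
    sequences, so two-value unpacking is safe.
    """
    train, val = model_selection.train_test_split(samples, test_size=val_rate)
    return train, val


# Hypothetical usage with parallel images/annotations lists:
# samples = list(zip(images, annotations))
# train_samples, val_samples = train_val_split(samples, val_rate=0.1)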

0 commit comments