import json
import logging
import math
import os
import pickle  # used below to persist trainer bookkeeping alongside TensorFlow checkpoints

import tensorflow as tf
from tensorboardX import SummaryWriter
from tensorflow import keras

from utils.utils import ensure_dir


class BaseTrainer:
    """
    Base class for all trainers
    """

    def __init__(self, model, loss, metrics, resume, config, train_logger=None):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.name = config['name']
        self.epochs = config['trainer']['epochs']
        self.save_freq = config['trainer']['save_freq']
        self.verbosity = config['trainer']['verbosity']
        self.summaryWriter = SummaryWriter()

        # Use a GPU when TensorFlow can see one and the config asks for it.
        if tf.config.list_physical_devices('GPU'):
            if config['cuda']:
                self.with_cuda = True
                self.gpus = {i: item for i, item in enumerate(self.config['gpus'])}
                device = '/GPU:0'
            else:
                self.with_cuda = False
                device = '/CPU:0'
        else:
            self.logger.warning("Warning: There's no GPU support on this machine, training is performed on CPU.")
            self.with_cuda = False
            device = '/CPU:0'

        # TensorFlow handles device placement automatically; the device string is kept
        # so subclasses can pin work explicitly with `with tf.device(self.device):`.
        self.device = device

        self.logger.debug('Model is initialized.')
        self._log_memory_usage()

        self.train_logger = train_logger

        self.optimizer = self.model.optimize(config['optimizer_type'], config['optimizer'])

        # Look the scheduler type up in keras.optimizers.schedules (e.g. 'ExponentialDecay').
        # Unlike torch.optim.lr_scheduler classes, Keras schedules are built from their own
        # kwargs only and are evaluated explicitly in train() instead of being stepped.
        self.lr_scheduler = getattr(keras.optimizers.schedules, config['lr_scheduler_type'], None)
        if self.lr_scheduler:
            self.lr_scheduler = self.lr_scheduler(**config['lr_scheduler'])
            self.lr_scheduler_freq = config['lr_scheduler_freq']
        self.monitor = config['trainer']['monitor']
        self.monitor_mode = config['trainer']['monitor_mode']
        assert self.monitor_mode in ('min', 'max')
        self.monitor_best = math.inf if self.monitor_mode == 'min' else -math.inf
        self.start_epoch = 1
        self.checkpoint_dir = os.path.join(config['trainer']['save_dir'], self.name)
        ensure_dir(self.checkpoint_dir)
        with open(os.path.join(self.checkpoint_dir, 'config.json'), 'w') as handle:
            json.dump(config, handle, indent=4, sort_keys=False)
        if resume:
            self._resume_checkpoint(resume)

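    # For illustration only: a hedged sketch of the config structure this constructor
    # reads. The key names come from the code above; the concrete values are hypothetical.
    #
    # {
    #     "name": "MyExperiment",
    #     "cuda": true,
    #     "gpus": [0],
    #     "optimizer_type": "Adam",
    #     "optimizer": {"learning_rate": 0.001},
    #     "lr_scheduler_type": "ExponentialDecay",
    #     "lr_scheduler": {"initial_learning_rate": 0.001, "decay_steps": 10, "decay_rate": 0.9},
    #     "lr_scheduler_freq": 1,
    #     "trainer": {
    #         "epochs": 50,
    #         "save_freq": 1,
    #         "verbosity": 2,
    #         "monitor": "loss",
    #         "monitor_mode": "min",
    #         "save_dir": "saved/"
    #     }
    # }
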
    def train(self):
        """
        Full training logic
        """
        self.logger.info('Training for {} epochs.'.format(self.epochs))
        for epoch in range(self.start_epoch, self.epochs + 1):
            try:
                result = self._train_epoch(epoch)
            except (tf.errors.ResourceExhaustedError, tf.errors.UnavailableError):
                # Log GPU memory usage before propagating OOM-style failures.
                self._log_memory_usage()
                raise

            log = {'epoch': epoch}
            for key, value in result.items():
                if key == 'metrics':
                    for i, metric in enumerate(self.metrics):
                        log[metric.__name__] = result['metrics'][i]
                elif key == 'val_metrics':
                    for i, metric in enumerate(self.metrics):
                        log['val_' + metric.__name__] = result['val_metrics'][i]
                else:
                    log[key] = value
            if self.train_logger is not None:
                self.train_logger.add_entry(log)
            if self.verbosity >= 1:
                for key, value in log.items():
                    self.logger.info('    {:15s}: {}'.format(str(key), value))
            if (self.monitor_mode == 'min' and log[self.monitor] < self.monitor_best) \
                    or (self.monitor_mode == 'max' and log[self.monitor] > self.monitor_best):
                self.monitor_best = log[self.monitor]
                self._save_checkpoint(epoch, log, save_best=True)
            if epoch % self.save_freq == 0:
                self._save_checkpoint(epoch, log)
            if self.lr_scheduler:
                # Keras schedules are callables; evaluate at the current epoch and
                # push the value into the optimizer's learning rate.
                lr = float(self.lr_scheduler(epoch))
                keras.backend.set_value(self.optimizer.learning_rate, lr)
                self.logger.info('New Learning Rate: {:.8f}'.format(lr))

            scalars = {'train_' + self.monitor: log[self.monitor]}
            if 'val_' + self.monitor in log:
                scalars['val_' + self.monitor] = log['val_' + self.monitor]
            self.summaryWriter.add_scalars('Train', scalars, epoch)
        self.summaryWriter.close()

    # TODO: Not available yet - the torch.cuda memory queries below have no
    # TensorFlow equivalent wired up, so only an empty template is logged.
    def _log_memory_usage(self):
        if not self.with_cuda:
            return

        template = """Memory Usage: \n{}"""
        usage = []
        for deviceID, device in self.gpus.items():
            deviceID = int(deviceID)
            # allocated = torch.cuda.memory_allocated(deviceID) / (1024 * 1024)
            # cached = torch.cuda.memory_cached(deviceID) / (1024 * 1024)
            # usage.append('    CUDA: {}  Allocated: {} MB  Cached: {} MB \n'.format(device, allocated, cached))

        content = ''.join(usage)
        content = template.format(content)

        self.logger.debug(content)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _save_checkpoint(self, epoch, log, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, also save the checkpoint under the 'model_best' prefix
        """
        arch = type(self.model).__name__
        # Model and optimizer variables are written with tf.train.Checkpoint (a plain Python
        # dict is not a trackable object, so tf.saved_model.save cannot persist it); the
        # bookkeeping the checkpoint cannot hold (epoch, logger, monitor_best, config) goes
        # into a pickled sidecar next to the checkpoint prefix.
        prefix = os.path.join(self.checkpoint_dir,
                              'checkpoint-epoch{:03d}-loss-{:.4f}'.format(epoch, log['loss']))
        checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
        checkpoint.write(prefix)
        state = {
            'arch': arch,
            'epoch': epoch,
            'logger': self.train_logger,
            'monitor_best': self.monitor_best,
            'config': self.config
        }
        with open(prefix + '.state.pkl', 'wb') as f:
            pickle.dump(state, f)
        if save_best:
            # A TF checkpoint prefix maps to several files, so the best model is written
            # again under its own prefix instead of renaming a single file.
            best_prefix = os.path.join(self.checkpoint_dir, 'model_best')
            checkpoint.write(best_prefix)
            with open(best_prefix + '.state.pkl', 'wb') as f:
                pickle.dump(state, f)
            self.logger.info("Saving current best: {} ...".format(best_prefix))
        else:
            self.logger.info("Saving checkpoint: {} ...".format(prefix))

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint prefix to be resumed (as written by _save_checkpoint)
        """
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        with open(resume_path + '.state.pkl', 'rb') as f:
            state = pickle.load(f)
        self.start_epoch = state['epoch'] + 1
        self.monitor_best = state['monitor_best']
        self.train_logger = state['logger']
        # Restore model and optimizer variables; TensorFlow places restored tensors on the
        # available device itself, so no explicit CUDA transfer is needed here.
        checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
        checkpoint.restore(resume_path).expect_partial()
        # self.config = state['config']
        self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch))
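

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer above): a minimal concrete
# trainer showing the contract BaseTrainer expects from subclasses. The name
# ExampleTrainer, the `data_loader` argument, and the assumption that
# `loss(y_true, y_pred)` and every metric `m(y_true, y_pred)` return scalar
# tensors are hypothetical; the point is how `_train_epoch` reports results
# back to `train()`.
# ---------------------------------------------------------------------------
class ExampleTrainer(BaseTrainer):
    def __init__(self, model, loss, metrics, resume, config, data_loader, train_logger=None):
        super().__init__(model, loss, metrics, resume, config, train_logger)
        self.data_loader = data_loader  # any iterable of (inputs, targets) batches

    def _train_epoch(self, epoch):
        total_loss = 0.0
        metric_sums = [0.0] * len(self.metrics)
        num_batches = 0
        for inputs, targets in self.data_loader:
            # Standard eager training step: forward pass under a GradientTape,
            # then apply gradients with the optimizer built in BaseTrainer.
            with tf.GradientTape() as tape:
                outputs = self.model(inputs, training=True)
                loss_value = self.loss(targets, outputs)
            grads = tape.gradient(loss_value, self.model.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
            total_loss += float(loss_value)
            for i, metric in enumerate(self.metrics):
                metric_sums[i] += float(metric(targets, outputs))
            num_batches += 1
        # train() expects a dict holding the configured monitor key (e.g. 'loss')
        # plus a 'metrics' list (and optionally 'val_metrics') in metric order.
        return {
            'loss': total_loss / max(num_batches, 1),
            'metrics': [s / max(num_batches, 1) for s in metric_sums],
        }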