# Copyright 2018 Dong-Hyun Lee, Kakao Brain.
#
# Copyright (c) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Training Config & Helper Classes """
import os
from typing import NamedTuple
from tqdm import tqdm
import torch
import torch.nn as nn
import checkpoint
class Config(NamedTuple):
""" Hyperparameters for optimization """
seed: int = 3431 # random seed
batch_size: int = 32
lr: int = 5e-5 # learning rate
n_epochs: int = 10 # the number of epoch
# `warm up` period = warmup(0.1)*total_steps
# linearly increasing learning rate from zero to the specified value(5e-5)
warmup: float = 0.1
save_steps: int = 100 # interval for saving model
total_steps: int = 100000 # total number of steps to train
data_parallel: bool = False
comments: str = "" # for comments in json file
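
# The warm-up comment above implies a schedule like the following minimal
# sketch (illustrative only, not used by this module; the real schedule is
# owned by whatever optimizer is passed to TrainLoop, and the constant rate
# after warm-up is an assumption here, not something this file specifies):
#
#     def warmup_lr(step: int, cfg: Config) -> float:
#         warmup_steps = cfg.warmup * cfg.total_steps
#         if step < warmup_steps:
#             return cfg.lr * step / warmup_steps  # linear ramp: 0 -> lr
#         return cfg.lr  # assumed constant after warm-up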


class TrainLoop(object):
    """Training Helper Class"""

    def __init__(self, cfg, model, data_iter, optimizer, save_dir, device):
        self.cfg = cfg # training config: see the Config class above
        self.model = model
        self.data_iter = data_iter # iterator that yields batches of tensors
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.device = device # device name

    def train(self, get_loss, model_file=None, pretrain_file=None, data_parallel=True):
        """ Train Loop """
        self.model.train() # set train mode
        self.load(model_file, pretrain_file)
        model = self.model.to(self.device)
        if data_parallel: # use Data Parallelism with Multi-GPU
            model = nn.DataParallel(model)

        global_step = 0 # global iteration count, independent of epochs
        for e in range(self.cfg.n_epochs):
            loss_sum = 0. # running sum of losses, for the per-epoch average
            iter_bar = tqdm(self.data_iter, desc='Iter (loss=X.XXX)')
            for i, batch in enumerate(iter_bar):
                batch = [t.to(self.device) for t in batch]

                self.optimizer.zero_grad()
                loss = get_loss(model, batch, global_step).mean() # mean() for Data Parallelism
                loss.backward()
                self.optimizer.step()

                global_step += 1
                loss_sum += loss.item()
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                if global_step % self.cfg.save_steps == 0: # periodic checkpoint
                    self.save(global_step)

                if self.cfg.total_steps and self.cfg.total_steps < global_step:
                    print(f"Epoch {e+1}/{self.cfg.n_epochs} : Average Loss {loss_sum/(i+1)}")
                    print('The total number of training steps has been reached.')
                    self.save(final=True) # save and stop once total_steps is exceeded
                    return
            print(f"Epoch {e+1}/{self.cfg.n_epochs} : Average Loss {loss_sum/(i+1)}")
        self.save(final=True)
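
    # `get_loss` maps (model, batch, global_step) to a loss tensor that
    # train() reduces with .mean(). A minimal, hypothetical callback for a
    # classification batch of (input_ids, segment_ids, input_mask, label_id)
    # tensors (names assumed, not defined by this module) might be:
    #
    #     def get_loss(model, batch, global_step):
    #         input_ids, segment_ids, input_mask, label_id = batch
    #         logits = model(input_ids, segment_ids, input_mask)
    #         return nn.CrossEntropyLoss()(logits, label_id)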

    def eval(self, evaluate, model_file):
        """ Evaluation Loop """
        self.model.eval() # set evaluation mode
        self.load(model_file, None)
        model = self.model.to(self.device)
        if self.cfg.data_parallel: # use Data Parallelism with Multi-GPU
            model = nn.DataParallel(model)

        results = [] # prediction results
        iter_bar = tqdm(self.data_iter, desc='Iter (acc=X.XXX)')
        for batch in iter_bar:
            batch = [t.to(self.device) for t in batch]
            with torch.no_grad(): # evaluate without gradient computation
                accuracy, result = evaluate(model, batch) # accuracy, shown on the progress bar
            results.append(result)
            iter_bar.set_description('Iter (acc=%5.3f)' % accuracy)
        return results
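
    # `evaluate` maps (model, batch) to an (accuracy, result) pair, where
    # `result` is collected across batches. A minimal, hypothetical callback
    # matching the batch layout assumed above might be:
    #
    #     def evaluate(model, batch):
    #         input_ids, segment_ids, input_mask, label_id = batch
    #         logits = model(input_ids, segment_ids, input_mask)
    #         _, label_pred = logits.max(1)
    #         result = (label_pred == label_id).float()
    #         accuracy = result.mean()
    #         return accuracy, result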

    def load(self, model_file, pretrain_file):
        """ Load a saved model or a pretrained transformer (a part of the model). """
        if model_file:
            print('Loading the model from', model_file)
            self.model.load_state_dict(torch.load(model_file))
        elif pretrain_file: # use a pretrained transformer
            print('Loading the pretrained model from', pretrain_file)
            if pretrain_file.endswith('.ckpt'): # TensorFlow checkpoint file
                checkpoint.load_model(self.model.transformer, pretrain_file)
            elif pretrain_file.endswith('.pt'): # PyTorch pretrained model file
                # load only the transformer part; key[12:] strips the
                # leading 'transformer.' prefix (12 characters) from each key
                self.model.transformer.load_state_dict(
                    {key[12:]: value
                     for key, value in torch.load(pretrain_file).items()
                     if key.startswith('transformer')}
                )
        return self.model
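
    # For example, a hypothetical saved key 'transformer.blocks.0.attn.proj_q.weight'
    # (name assumed, not defined here) would be loaded into
    # self.model.transformer under 'blocks.0.attn.proj_q.weight'.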

    def save(self, i=0, final=False):
        """ Save the current model. """
        file = "model_final.pt" if final else f"model_steps_{i}.pt"
        # save the state dict of the underlying model, not the nn.DataParallel wrapper
        torch.save(self.model.state_dict(), os.path.join(self.save_dir, file))


if __name__ == '__main__':
    pass
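    # A minimal, hypothetical usage sketch (the model, data_iter, optimizer,
    # and callbacks are placeholders this module does not provide; eval works
    # analogously with a TrainLoop built around an evaluation data_iter):
    #
    #     cfg = Config(n_epochs=3, total_steps=10000)
    #     trainer = TrainLoop(cfg, model, data_iter, optimizer,
    #                         save_dir='./save', device='cuda')
    #     trainer.train(get_loss, pretrain_file='bert_model.ckpt',
    #                   data_parallel=cfg.data_parallel)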