-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
31 lines (21 loc) · 819 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from misc.utils import cross_entropy2d
def train_epoch(model, optimizer, criterion, train_loader, infos, args):
    """Run one optimisation pass over ``train_loader``.

    For every batch: zero the gradients, forward the inputs through
    ``model``, compute ``criterion(outputs, labels)``, backpropagate and take
    one ``optimizer.step()``. Every ``args.checkpoint_every`` iterations the
    current loss is printed and recorded in ``infos['train_loss']``.

    Args:
        model: callable mapping a batch of inputs to outputs (presumably the
            module whose parameters ``optimizer`` updates — confirm at caller).
        optimizer: object exposing ``zero_grad()`` and ``step()``.
        criterion: loss function called as ``criterion(outputs, labels)``;
            its result must support ``backward()`` and ``item()``.
        train_loader: iterable of batches; ``batch[0]`` is used as the input
            and ``batch[1]`` as the target.
        infos: mutable training-state dict. Reads ``'iteration'`` (global
            step counter, continued across epochs) and ``'epoch'`` (used only
            in the log line); appends to the list ``'train_loss'``; writes the
            updated ``'iteration'`` back when the loader is exhausted.
        args: namespace with ``use_cuda`` (move each batch with ``.cuda()``)
            and ``checkpoint_every`` (log every N iterations; a value
            ``<= 0`` disables logging).

    Returns:
        None. Progress is reported through ``infos`` and stdout.
    """
    iteration = infos['iteration']
    epoch = infos['epoch']
    # NOTE(review): model.train() is not called here. If the caller puts the
    # model in eval() mode for validation, confirm it switches back before
    # calling this function, otherwise dropout/batch-norm stay in eval mode.
    for batch in train_loader:
        iteration += 1
        optimizer.zero_grad()
        inputs = batch[0]
        labels = batch[1]
        if args.use_cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Check the positive interval first so the modulo is never evaluated
        # with a zero divisor (checkpoint_every == 0 means "no logging").
        if args.checkpoint_every > 0 and iteration % args.checkpoint_every == 0:
            # .item() forces a host sync; call it once and reuse the value.
            loss_value = loss.item()
            print('Epoch:{},iteration:{},train_loss:{}'.format(epoch, iteration, loss_value))
            infos['train_loss'].append(loss_value)
    infos['iteration'] = iteration