-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_eval_utils.py
76 lines (49 loc) · 1.89 KB
/
train_eval_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import sys
from tqdm import tqdm
import torch
from distributed_utils import reduce_value, is_main_process
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train ``model`` for one pass over ``data_loader`` and return the mean loss.

    Args:
        model: Module to train; it is switched to training mode.
        optimizer: Optimizer whose gradients are cleared up front and which
            is stepped (then cleared again) once per batch.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to; either a ``torch.device``
            or a device string such as ``"cpu"`` / ``"cuda:0"``.
        epoch: Epoch number, used only in the progress-bar description.

    Returns:
        float: Running mean of the per-step loss (``0.0`` if the loader
        yields no batches).

    Raises:
        SystemExit: With status 1 as soon as a non-finite loss is seen.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    optimizer.zero_grad()

    # Only the main process wraps the loader in a progress bar.
    if is_main_process():
        data_loader = tqdm(data_loader)

    for step, data in enumerate(data_loader):
        images, labels = data

        pred = model(images.to(device))

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        # NOTE(review): reduce_value presumably averages the loss across
        # processes -- confirm against distributed_utils.
        loss = reduce_value(loss, average=True)
        # Incremental mean: fold this step's loss into the mean of the
        # previous `step` losses.
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)

        # Only the main process reports the running mean loss.
        if is_main_process():
            data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        # Abort before applying an update computed from a NaN/inf loss.
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    # Wait for outstanding CUDA work to finish. Checking the normalised
    # device type (instead of `device != torch.device("cpu")`) keeps string
    # devices and non-CUDA backends from reaching torch.cuda.synchronize.
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)

    return mean_loss.item()
@torch.no_grad()
def evaluate(model, data_loader, device):
    """Count how many samples from ``data_loader`` ``model`` classifies correctly.

    Runs with gradient tracking disabled and with the model in eval mode.

    Args:
        model: Module to evaluate; called on each batch of images.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to; either a ``torch.device``
            or a device string such as ``"cpu"`` / ``"cuda:0"``.

    Returns:
        float: Number of predictions equal to their label, after being
        passed through ``reduce_value(..., average=False)``.
    """
    model.eval()

    # Accumulates the number of correctly predicted samples.
    sum_num = torch.zeros(1).to(device)

    # Only the main process wraps the loader in a progress bar.
    if is_main_process():
        data_loader = tqdm(data_loader)

    for images, labels in data_loader:
        pred = model(images.to(device))
        # Index of the largest score along dim 1 is the predicted class.
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()

    # Wait for outstanding CUDA work to finish. Checking the normalised
    # device type (instead of `device != torch.device("cpu")`) keeps string
    # devices and non-CUDA backends from reaching torch.cuda.synchronize.
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)

    # NOTE(review): reduce_value with average=False presumably sums the
    # per-process counts -- confirm against distributed_utils.
    sum_num = reduce_value(sum_num, average=False)

    return sum_num.item()