-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path predict.py
30 lines (25 loc) · 1004 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import argparse
from model import Net
# from batch_model import Net
from utils import to_device
from data import ARESdataset
from torch.utils.data import DataLoader
def main(args):
    """Run a trained ``Net`` over a dataset and print per-sample results.

    For every sample loaded from ``args.dir`` the model output is printed
    next to the ``rms`` value yielded by the dataset and the absolute
    difference between the two.

    Args:
        args: Parsed CLI namespace with ``dir`` (directory passed to
            ``ARESdataset``), ``device`` (torch device string used when
            building the model, loading the weights and moving each batch)
            and ``model_path`` (path to a saved ``state_dict``).
    """
    dataset = ARESdataset(args.dir)
    # batch_size=1: the ``.item()`` calls below assume every batch yields
    # single-element tensors (TODO confirm Net's output shape per sample).
    # shuffle=False: order is irrelevant for inference, and a fixed order
    # keeps the printed results reproducible and comparable across runs.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    net = Net(device=args.device)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources; on torch >= 1.13 consider weights_only=True.
    state_dict = torch.load(args.model_path, map_location=args.device)
    net.load_state_dict(state_dict)
    # Evaluation mode (affects layers such as dropout/batch-norm, if Net
    # contains any).
    net.eval()

    # Inference only: no gradients are needed.
    with torch.no_grad():
        for batch in dataloader:
            V, atoms_info, rms, atoms_lens = (
                to_device(x, args.device) for x in batch
            )
            out = net(V, atoms_info, atoms_lens)
            gap = (out - rms).abs()
            print(f'out:{out.item()} rms:{rms.item()} gap:{gap.item()}')
if __name__ == '__main__':
    # Command-line entry point: parse options and run prediction.
    parser = argparse.ArgumentParser(
        description='Print model predictions next to the dataset rms values.'
    )
    parser.add_argument(
        '--dir', type=str, default='data/val',
        help='dataset directory passed to ARESdataset (default: %(default)s)',
    )
    parser.add_argument(
        '--device', type=str, default='cpu',
        help='torch device to run on, e.g. "cpu" or "cuda:0" '
             '(default: %(default)s)',
    )
    parser.add_argument(
        '--model_path', type=str, required=True,
        help='path to the saved model state_dict to load',
    )
    args = parser.parse_args()
    main(args)