-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy patheval.py
56 lines (42 loc) · 1.51 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python3
import sys
from tensorflow.python.keras import backend
import tensorflow as tf
from tqdm import tqdm
import numpy as np
import datasets
import network
import metrics
import argparse
from utils import write_spring_predictions, make_spring_folder
def eval(checkpoint, data, outstream):
    """Evaluate a scene-flow checkpoint on a dataset and report the metrics.

    NOTE: this name shadows the builtin ``eval``; it is kept as-is because
    callers (including this script's entry point) refer to it.

    Args:
        checkpoint: weights location handed to ``Network.load_weights``.
        data: iterable yielding ``(images, gt)`` pairs, one pair per
            evaluated sequence (exact tensor layout is defined by the
            ``datasets`` module — not visible here).
        outstream: text stream the accumulated metrics are printed to.
    """
    # Build and run everything inside the Keras backend graph.
    graph = backend.get_graph()
    with graph.as_default():
        model = network.Network()
        model.load_weights(checkpoint)
        scene_flow_metrics = metrics.SceneFlowMetrics()

        for seq_idx, (images, gt) in enumerate(data):
            # Progress goes to stdout regardless of ``outstream``.
            print("Evaluating sequence %d..." % seq_idx)
            prediction = model(inputs=images)
            scene_flow_metrics.update_state(gt, prediction)

        # Report the metrics accumulated over all sequences.
        scene_flow_metrics.print(stream=outstream)
def _load_dataset(checkpoint):
    """Build the evaluation dataset for the benchmark named in *checkpoint*.

    The benchmark is chosen by substring match on the checkpoint path:
    'kitti' is tested first, then 'spring'.

    Args:
        checkpoint: checkpoint path given on the command line.

    Returns:
        The dataset object produced by ``datasets.get_kitti_dataset`` or
        ``datasets.get_spring_dataset`` (batch size 1 in both cases).

    Raises:
        ValueError: if the path contains neither 'kitti' nor 'spring'.
    """
    if 'kitti' in checkpoint:
        return datasets.get_kitti_dataset(datasets.KITTI_VALIDATION_IDXS, batch_size=1)
    if 'spring' in checkpoint:
        root = datasets.BASEPATH_SPRING
        split = 'train'
        scene_dict = datasets.prepare_spring_data_dict(root, split)
        # Only the second set of sequence indices returned by the split
        # helper is evaluated.
        _, indices = datasets.split_spring_seq(root, split)
        test_dataset = datasets.SpringDataset(root, indices, scene_dict, split)
        return datasets.get_spring_dataset(test_dataset, 1, split=split)
    raise ValueError(f"checkpoint name should have 'spring' or 'kitti' but given {checkpoint}")


def main(argv=None):
    """Command-line entry point: evaluate one checkpoint and print metrics.

    Args:
        argv: argument list to parse (defaults to ``sys.argv[1:]``).

    Raises:
        SystemExit: via argparse when the checkpoint argument is missing.
        ValueError: if the checkpoint path names no known benchmark.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a scene-flow checkpoint on KITTI or Spring.")
    parser.add_argument(
        "checkpoint",
        help="checkpoint path; must contain 'kitti' or 'spring' to pick the dataset")
    args = parser.parse_args(argv)

    data = _load_dataset(args.checkpoint)
    eval(args.checkpoint, data, outstream=sys.stdout)


if __name__ == "__main__":
    main()