# train_freq_prior.py
# Forked from dmlc/dgl.
# (Line-number gutter 1-116 from the web page listing; not part of the script.)
import argparse
import json
import os
import pickle
import numpy as np
def parse_args(argv=None):
    """Parse the command-line options of the frequency-prior script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, so existing zero-argument
            callers behave exactly as before.

    Returns:
        argparse.Namespace with two attributes:
        ``overlap`` (bool, set by ``--overlap``) and
        ``json_path`` (str, set by ``--json-path``).
    """
    parser = argparse.ArgumentParser(
        description="Train the Frequency Prior For RelDN."
    )
    parser.add_argument(
        "--overlap", action="store_true", help="Only count overlap boxes."
    )
    parser.add_argument(
        "--json-path",
        type=str,
        default="~/.mxnet/datasets/visualgenome",
        # The script joins this directory with "rel_annotations_train.json".
        help="Directory containing rel_annotations_train.json.",
    )
    return parser.parse_args(argv)
# Parse the command line once at module level; the names assigned here are
# module globals that the rest of the script reads.
args = parse_args()
# True when --overlap was given on the command line.
use_overlap = args.overlap
# Expand a leading "~" in the --json-path directory.
PATH_TO_DATASETS = os.path.expanduser(args.json_path)
# Full path of the training relation-annotation file inside that directory.
path_to_json = os.path.join(PATH_TO_DATASETS, "rel_annotations_train.json")
# Boxes are given in (y1, y2, x1, x2) order.
def with_overlap(boxA, boxB):
    """Return 1 if the two boxes strictly overlap, otherwise 0.

    Each box is indexed as ``(y1, y2, x1, x2)``. The intersection must have
    positive extent on both axes, so boxes that merely touch along an edge
    or corner are reported as non-overlapping.
    """
    left = max(boxA[2], boxB[2])
    right = min(boxA[3], boxB[3])
    # No horizontal intersection: the vertical extent is irrelevant.
    if not right > left:
        return 0

    top = max(boxA[0], boxB[0])
    bottom = min(boxA[1], boxB[1])
    return 1 if bottom > top else 0
def box_ious(boxes):
    """Build the symmetric pairwise overlap matrix for a sequence of boxes.

    Entry ``[i, j]`` holds the value ``with_overlap`` returns for boxes ``i``
    and ``j``. The diagonal is left at zero. The result is an ``(n, n)``
    float array as produced by ``np.zeros``.
    """
    count = len(boxes)
    overlap = np.zeros((count, count))
    # Visit each unordered pair once and mirror the flag across the diagonal.
    for first in range(count - 1):
        for second in range(first + 1, count):
            flag = with_overlap(boxes[first], boxes[second])
            overlap[first, second] = overlap[second, first] = flag
    return overlap
# Load the training relation annotations. The loops below treat the JSON as a
# mapping whose values are lists of relation records.
with open(path_to_json, "r") as f:
    train_data = json.load(f)

# fg_matrix[s, o, p + 1] counts annotated (subject, object, predicate) triples;
# slot 0 of the last axis is overwritten with the background counts below.
# NOTE(review): presumably 150 object categories and 50 predicates plus one
# background slot -- confirm against the dataset.
fg_matrix = np.zeros((150, 150, 51), dtype=np.int64)
# bg_matrix[s, o] counts candidate (subject class, object class) box pairs.
bg_matrix = np.zeros((150, 150), dtype=np.int64)

for item in train_data.values():
    # Distinct boxes of this entry, each mapped to the first category seen.
    gt_box_to_label = {}
    for rel in item:
        sub_bbox = rel["subject"]["bbox"]
        ob_bbox = rel["object"]["bbox"]
        sub_class = rel["subject"]["category"]
        ob_class = rel["object"]["category"]
        rel_class = rel["predicate"]

        # Lists are unhashable, so boxes are keyed by their tuple form.
        sub_node = tuple(sub_bbox)
        ob_node = tuple(ob_bbox)
        gt_box_to_label.setdefault(sub_node, sub_class)
        gt_box_to_label.setdefault(ob_node, ob_class)

        fg_matrix[sub_class, ob_class, rel_class + 1] += 1

    if use_overlap:
        gt_boxes = [*gt_box_to_label]
        gt_classes = np.array([*gt_box_to_label.values()])
        iou_mat = box_ious(gt_boxes)
        cols, rows = np.where(iou_mat)
        if len(cols) == 0:
            # No two boxes overlap: fall back to every ordered pair of
            # distinct boxes. `bool` replaces the `np.bool` alias, which was
            # removed in NumPy 1.24 and raised AttributeError here.
            all_possib = np.ones_like(iou_mat, dtype=bool)
            np.fill_diagonal(all_possib, 0)
            cols, rows = np.where(all_possib)
        for col, row in zip(cols, rows):
            bg_matrix[gt_classes[col], gt_classes[row]] += 1
    else:
        # Count every ordered pair of distinct boxes.
        for b1, l1 in gt_box_to_label.items():
            for b2, l2 in gt_box_to_label.items():
                if b1 == b2:
                    continue
                bg_matrix[l1, l2] += 1

eps = 1e-3
# Add one to every background count so no class pair has a zero count.
bg_matrix += 1
fg_matrix[:, :, 0] = bg_matrix
# Normalise over the last axis and take the log; eps is added to the
# denominator and inside the log.
pred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps)

# The output file name records whether the overlap filter was applied.
output_name = "freq_prior_overlap.pkl" if use_overlap else "freq_prior.pkl"
with open(output_name, "wb") as f:
    pickle.dump(pred_dist, f)