-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathneighbors.py
59 lines (47 loc) · 1.46 KB
/
neighbors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
import os.path
import sys
import torch
import torch.nn.functional as F
from annoy import AnnoyIndex
def twod_map(array, mapping):
    """Translate every element of a two-level nested iterable via ``mapping``.

    Args:
        array: Iterable of rows, each row itself an iterable of elements
            that are valid keys (or indices) for ``mapping``.
        mapping: Any object supporting ``mapping[element]`` lookups, such as
            a dict or a list.

    Returns:
        A new list of lists with the same row structure as ``array`` in
        which each element ``j`` has been replaced by ``mapping[j]``.

    Raises:
        KeyError or IndexError: Propagated from ``mapping`` when an element
            has no entry.
    """
    return [[mapping[item] for item in row] for row in array]
def create_index(X, index_type='annoy', n_trees=100):
    """Build a nearest-neighbour index over the rows of ``X``.

    Args:
        X: Two-dimensional collection of vectors, one vector per row. The
            Annoy path reads ``X.shape[1]`` and iterates over the rows, so a
            torch tensor or a NumPy array both work. The FAISS path converts
            ``X`` to a C-contiguous float32 NumPy array.
        index_type: ``'faiss'`` builds a FAISS inner-product index; any
            other value builds an Annoy index with the ``'angular'`` metric.
        n_trees: Number of trees passed to ``AnnoyIndex.build``. Ignored by
            the FAISS path.

    Returns:
        The populated index object (a FAISS index or an ``AnnoyIndex``).

    Raises:
        ImportError: If ``index_type == 'faiss'`` and faiss is not installed.
    """
    if index_type == 'faiss':
        # Imported lazily: the module does not import faiss at top level, and
        # callers that only use Annoy should not need faiss installed.
        import faiss

        if torch.is_tensor(X):
            # Drop autograd tracking and move to host memory so the NumPy
            # conversion below cannot fail on grad-requiring or CUDA tensors.
            X = X.detach().cpu().numpy()
        X_cont = np.ascontiguousarray(X, dtype=np.float32)
        n, dim = X_cont.shape
        if n < 200000:
            # Small collections: flat inner-product index, no training step.
            index = faiss.IndexFlatIP(dim)
        else:
            # Large collections: inverted-file index over a flat quantizer.
            n_cells = 2048
            index = faiss.IndexIVFFlat(
                faiss.IndexFlatIP(dim),
                dim,
                n_cells,
                faiss.METRIC_INNER_PRODUCT,
            )
            index.nprobe = 16
            # Train on at most the first million rows.
            n_train = min(n, 1000000)
            index.train(X_cont[:n_train])
        index.add(X_cont)
        # NOTE(review): FAISS ranks by inner product while Annoy uses the
        # angular metric; the two presumably agree only when rows are
        # L2-normalised -- confirm against the callers.
    else:
        dim = X.shape[1]
        index = AnnoyIndex(dim, metric='angular')
        for i, v in enumerate(X):
            index.add_item(i, v)
        index.build(n_trees)
    return index
def find_closest(embeddings, k, index, queries, index_type='annoy'):
    """Look up the ``k`` nearest indexed items for each query row.

    Args:
        embeddings: Indexable collection of vectors (torch tensor or NumPy
            array); ``embeddings[q]`` must yield the vector for query ``q``.
        k: Number of neighbours to request per query.
        index: Index object to search. The FAISS path calls
            ``index.assign(vectors, k)``; the other path calls
            ``index.get_nns_by_vector(vector, k)`` once per query.
        queries: Row indices into ``embeddings``. For the FAISS path it is
            used as a single fancy index (``embeddings[queries]``); for the
            other path it is iterated element by element.
        index_type: ``'faiss'`` selects the batched FAISS lookup; any other
            value selects the per-query Annoy lookup.

    Returns:
        For ``'faiss'``: whatever ``index.assign`` returns (per the FAISS
        documentation, an integer array of neighbour labels with one row per
        query). Otherwise: a list with one list of neighbour ids per query.
        Note that the two paths return different container types.
    """
    if index_type == 'faiss':
        points = embeddings[queries]
        if torch.is_tensor(points):
            # detach() drops autograd tracking; cpu() is required because
            # Tensor.numpy() refuses tensors that live on a CUDA device.
            points = points.detach().cpu().numpy()
        # FAISS expects C-contiguous float32 input.
        q = np.ascontiguousarray(points, dtype=np.float32)
        return index.assign(q, k)

    return [index.get_nns_by_vector(embeddings[q], k) for q in queries]