Skip to content

Commit 93d95a8

Browse files
authored
Update _farthest_first_traversal.py
1 parent bd27bac commit 93d95a8

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

src/lobster/data/_farthest_first_traversal.py

+72
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import torch
2+
import edlib
3+
import heapq
4+
import numpy as np
5+
from typing import Optional
26

37

48
class FarthestFirstTraversal:
@@ -77,3 +81,71 @@ def _levenshtein(s1, s2):
7781
previous_row = current_row
7882

7983
return previous_row[-1]
84+
85+
def edit_dist(x: str, y: str):
86+
"""
87+
Computes the edit distance between two strings.
88+
"""
89+
return edlib.align(x, y)["editDistance"]
90+
91+
def ranked_fft(
92+
library: np.ndarray,
93+
ranking_scores: Optional[np.ndarray] = None,
94+
n: int = 2,
95+
descending: bool = False,
96+
):
97+
"""
98+
Farthest-first traversal of a library of strings.
99+
If `ranking_scores` is provided, the scores are used to pick the starting point and break ties.
100+
Args:
101+
library: A numpy array of shape (N,) where N is the number of sequences.
102+
ranking_scores: A numpy array of shape (N,) containing the ranking scores of the sequences in the library.
103+
n: The number of sequences to return.
104+
Returns:
105+
A numpy array of shape (n,) containing the indices of the selected sequences.
106+
"""
107+
if ranking_scores is None:
108+
ranking_scores = np.zeros(library.shape[0])
109+
remaining_indices = list(range(library.shape[0]))
110+
else:
111+
if descending:
112+
ranking_scores = -ranking_scores
113+
remaining_indices = list(np.argsort(ranking_scores))
114+
115+
selected = [remaining_indices.pop(0)]
116+
117+
if n == 1:
118+
return np.array(selected)
119+
120+
pq = []
121+
# First pass through library
122+
for index in remaining_indices:
123+
# Pushing with heapq, negate dist to simulate max-heap with min-heap
124+
(
125+
heapq.heappush(
126+
pq,
127+
(
128+
-edit_dist(library[index], library[selected[0]]),
129+
ranking_scores[index],
130+
index,
131+
1,
132+
),
133+
),
134+
)
135+
136+
for _ in range(1, n):
137+
while True:
138+
neg_dist, score, idx, num_checked = heapq.heappop(pq)
139+
# Check if the top of the heap has been checked against all currently selected sequences
140+
if num_checked < len(selected):
141+
min_dist = min(
142+
edit_dist(library[idx], library[selected[i]])
143+
for i in range(num_checked, len(selected))
144+
)
145+
min_dist = min(min_dist, -neg_dist)
146+
heapq.heappush(pq, (-min_dist, score, idx, len(selected)))
147+
else:
148+
selected.append(idx)
149+
break
150+
151+
return np.array(selected)

0 commit comments

Comments
 (0)