|
1 | 1 | import torch
|
| 2 | +import edlib |
| 3 | +import heapq |
| 4 | +import numpy as np |
| 5 | +from typing import Optional |
2 | 6 |
|
3 | 7 |
|
4 | 8 | class FarthestFirstTraversal:
|
@@ -77,3 +81,71 @@ def _levenshtein(s1, s2):
|
77 | 81 | previous_row = current_row
|
78 | 82 |
|
79 | 83 | return previous_row[-1]
|
| 84 | + |
| 85 | +def edit_dist(x: str, y: str): |
| 86 | + """ |
| 87 | + Computes the edit distance between two strings. |
| 88 | + """ |
| 89 | + return edlib.align(x, y)["editDistance"] |
| 90 | + |
| 91 | +def ranked_fft( |
| 92 | + library: np.ndarray, |
| 93 | + ranking_scores: Optional[np.ndarray] = None, |
| 94 | + n: int = 2, |
| 95 | + descending: bool = False, |
| 96 | +): |
| 97 | + """ |
| 98 | + Farthest-first traversal of a library of strings. |
| 99 | + If `ranking_scores` is provided, the scores are used to pick the starting point and break ties. |
| 100 | + Args: |
| 101 | + library: A numpy array of shape (N,) where N is the number of sequences. |
| 102 | + ranking_scores: A numpy array of shape (N,) containing the ranking scores of the sequences in the library. |
| 103 | + n: The number of sequences to return. |
| 104 | + Returns: |
| 105 | + A numpy array of shape (n,) containing the indices of the selected sequences. |
| 106 | + """ |
| 107 | + if ranking_scores is None: |
| 108 | + ranking_scores = np.zeros(library.shape[0]) |
| 109 | + remaining_indices = list(range(library.shape[0])) |
| 110 | + else: |
| 111 | + if descending: |
| 112 | + ranking_scores = -ranking_scores |
| 113 | + remaining_indices = list(np.argsort(ranking_scores)) |
| 114 | + |
| 115 | + selected = [remaining_indices.pop(0)] |
| 116 | + |
| 117 | + if n == 1: |
| 118 | + return np.array(selected) |
| 119 | + |
| 120 | + pq = [] |
| 121 | + # First pass through library |
| 122 | + for index in remaining_indices: |
| 123 | + # Pushing with heapq, negate dist to simulate max-heap with min-heap |
| 124 | + ( |
| 125 | + heapq.heappush( |
| 126 | + pq, |
| 127 | + ( |
| 128 | + -edit_dist(library[index], library[selected[0]]), |
| 129 | + ranking_scores[index], |
| 130 | + index, |
| 131 | + 1, |
| 132 | + ), |
| 133 | + ), |
| 134 | + ) |
| 135 | + |
| 136 | + for _ in range(1, n): |
| 137 | + while True: |
| 138 | + neg_dist, score, idx, num_checked = heapq.heappop(pq) |
| 139 | + # Check if the top of the heap has been checked against all currently selected sequences |
| 140 | + if num_checked < len(selected): |
| 141 | + min_dist = min( |
| 142 | + edit_dist(library[idx], library[selected[i]]) |
| 143 | + for i in range(num_checked, len(selected)) |
| 144 | + ) |
| 145 | + min_dist = min(min_dist, -neg_dist) |
| 146 | + heapq.heappush(pq, (-min_dist, score, idx, len(selected))) |
| 147 | + else: |
| 148 | + selected.append(idx) |
| 149 | + break |
| 150 | + |
| 151 | + return np.array(selected) |
0 commit comments