-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
updated the strategy design to separate traversal from node selection #152
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,41 @@ | |
from graph_retriever.types import Node | ||
|
||
|
||
class NodeTracker: | ||
"""Helper class for tracking traversal progress.""" | ||
|
||
def __init__(self, select_k: int, max_depth: int | None) -> None:
    """
    Create a tracker with nothing selected and nothing queued.

    Parameters
    ----------
    select_k :
        Target number of nodes to select; `remaining` is computed from it.
    max_depth :
        Depth limit read by `traverse`, or `None`.
    """
    self._select_k: int = select_k
    self._max_depth: int | None = max_depth
    # IDs checked against this set are the `str` keys of the node dicts.
    self._visited_nodes: set[str] = set()
    # Nodes queued for traversal, keyed by node ID.
    self.to_traverse: dict[str, Node] = {}
    # Nodes chosen for the result set, keyed by node ID.
    self.selected: dict[str, Node] = {}
|
||
@property | ||
def remaining(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe |
||
"""The remaining number of nodes to be selected""" | ||
return max(self._select_k - len(self.selected), 0) | ||
|
||
def select(self, nodes: dict[str, Node]) -> None: | ||
"""Select nodes to be included in the result set.""" | ||
self.selected.update(nodes) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should I guess most use cases would be better off with |
||
|
||
def traverse(self, nodes: dict[str, Node]) -> int: | ||
"""Select nodes to be included in the next traversal.""" | ||
for id, node in nodes.items(): | ||
if id in self._visited_nodes: | ||
continue | ||
if self._max_depth is None or node.depth <= self._max_depth: | ||
continue | ||
self.to_traverse[id] = node | ||
return len(self.to_traverse) | ||
|
||
def select_and_traverse(self, nodes: dict[str, Node]) -> None:
    """
    Select nodes to be included in the result set and the next traversal.

    Delegates to `select` and then `traverse` with the same `nodes`; the
    count returned by `traverse` is discarded.

    Parameters
    ----------
    nodes :
        Nodes to select and queue, keyed by node ID.
    """
    self.select(nodes)
    self.traverse(nodes)
|
||
|
||
@dataclasses.dataclass(kw_only=True) | ||
class Strategy(abc.ABC): | ||
""" | ||
|
@@ -21,31 +56,36 @@ class Strategy(abc.ABC): | |
|
||
Parameters | ||
---------- | ||
k : | ||
Maximum number of nodes to retrieve during traversal. | ||
select_k : | ||
Maximum number of nodes to select and return during traversal. | ||
start_k : | ||
Number of documents to fetch via similarity for starting the traversal. | ||
Number of nodes to fetch via similarity for starting the traversal. | ||
Added to any initial roots provided to the traversal. | ||
adjacent_k : | ||
Number of documents to fetch for each outgoing edge. | ||
Number of nodes to fetch for each outgoing edge. | ||
max_traverse : | ||
Maximum number of nodes to traverse outgoing edges from before returning. | ||
If `None`, there is no limit. | ||
max_depth : | ||
Maximum traversal depth. If `None`, there is no limit. | ||
""" | ||
|
||
k: int = 5 | ||
select_k: int = 5 | ||
start_k: int = 4 | ||
adjacent_k: int = 10 | ||
max_traverse: int | None = None | ||
max_depth: int | None = None | ||
|
||
_query_embedding: list[float] = dataclasses.field(default_factory=list) | ||
|
||
@abc.abstractmethod | ||
def discover_nodes(self, nodes: dict[str, Node]) -> None: | ||
def iteration(self, *, nodes: dict[str, Node], tracker: NodeTracker) -> None: | ||
""" | ||
Add discovered nodes to the strategy. | ||
Process the newly discovered nodes on each iteration. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It may be worth elaborating on this -- the "newly discovered" part seems important. For instance, it means that if a strategy wants to potentially visit some of these nodes in a future iteration (as MMR does), it needs to remember them and select/traverse them later. |
||
|
||
This method updates the strategy's state with nodes discovered during | ||
the traversal process. | ||
This method should call `traverse` and/or `select` as appropriate | ||
to update the nodes that need to be traversed in this iteration or | ||
selected at the end of the retrieval, respectively. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -54,43 +94,20 @@ def discover_nodes(self, nodes: dict[str, Node]) -> None: | |
""" | ||
... | ||
|
||
@abc.abstractmethod | ||
def select_nodes(self, *, limit: int) -> Iterable[Node]: | ||
""" | ||
Select discovered nodes to visit in the next iteration. | ||
|
||
This method determines which nodes will be traversed next. If it returns | ||
an empty list, traversal ends even if fewer than `k` nodes have been selected. | ||
|
||
Parameters | ||
---------- | ||
limit : | ||
Maximum number of nodes to select. | ||
|
||
Returns | ||
------- | ||
: | ||
Selected nodes for the next iteration. Traversal ends if this is empty. | ||
""" | ||
... | ||
|
||
def finalize_nodes(self, nodes: Iterable[Node]) -> Iterable[Node]: | ||
def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]: | ||
""" | ||
Finalize the selected nodes. | ||
|
||
This method is called before returning the final set of nodes. | ||
|
||
Parameters | ||
---------- | ||
nodes : | ||
Nodes selected for finalization. | ||
|
||
Returns | ||
------- | ||
: | ||
Finalized nodes. | ||
""" | ||
return nodes | ||
# Take the first `self.select_k` selected items. | ||
# Strategies may override finalize to perform reranking if needed. | ||
return list(selected)[: self.select_k] | ||
|
||
@staticmethod | ||
def build( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe add a property to compute the remaining unselected.