updated the strategy design to separate traversal from node selection #152
@@ -10,6 +10,28 @@
 from graph_retriever.types import Node
 
 
+class NodeTracker:
+    """Helper class for tracking traversal progress."""
+
+    def __init__(self) -> None:
+        self.visited_ids: set[str] = set()
Review comment: Should ...
+        self.to_traverse: dict[str, Node] = {}
+        self.selected: dict[str, Node] = {}
+
+    def select(self, nodes: dict[str, Node]) -> None:
+        """Select nodes to be included in the result set."""
+        self.selected.update(nodes)
Review comment: Should ... I guess most use cases would be better off with ...
+
+    def traverse(self, nodes: dict[str, Node]) -> None:
+        """Select nodes to be included in the next traversal."""
+        self.to_traverse.update(nodes)
Review comment: Note: if this doesn't track visited nodes, already-visited nodes can be added (growing the traverse set) without actually being traversed again. Tracking visited nodes here may make the behavior easier to reason about.
Review comment: Have these return the number of nodes that will actually be traversed or selected.
Review comment: empty the ...
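A minimal sketch of what the first two suggestions above could look like, not the PR's implementation. `Node` here is a stand-in for `graph_retriever.types.Node`, and `CountingNodeTracker` and `start_iteration` are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for graph_retriever.types.Node."""

    id: str


@dataclass
class CountingNodeTracker:
    """Illustrative NodeTracker variant that tracks visited nodes and returns counts."""

    visited_ids: set[str] = field(default_factory=set)
    to_traverse: dict[str, Node] = field(default_factory=dict)
    selected: dict[str, Node] = field(default_factory=dict)

    def select(self, nodes: dict[str, Node]) -> int:
        """Select nodes and return how many were newly selected."""
        new = {nid: node for nid, node in nodes.items() if nid not in self.selected}
        self.selected.update(new)
        return len(new)

    def traverse(self, nodes: dict[str, Node]) -> int:
        """Queue unvisited nodes and return how many were newly queued."""
        new = {
            nid: node
            for nid, node in nodes.items()
            if nid not in self.visited_ids and nid not in self.to_traverse
        }
        self.to_traverse.update(new)
        return len(new)

    def start_iteration(self) -> dict[str, Node]:
        """Hypothetical helper: hand out the queued nodes and mark them visited."""
        batch, self.to_traverse = self.to_traverse, {}
        self.visited_ids.update(batch)  # updating a set from a dict adds its keys
        return batch
```

With this shape, a strategy can tell from the return value whether a call to `traverse` actually scheduled any work, and the traverse set never grows with nodes that have already been visited.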
+
+    def select_and_traverse(self, nodes: dict[str, Node]) -> None:
+        """Select nodes to be included in the result set and the next traversal."""
+        self.select(nodes=nodes)
Review comment: nit: doesn't need to be named -- just ...
+        self.traverse(nodes=nodes)
+
+
 @dataclasses.dataclass(kw_only=True)
 class Strategy(abc.ABC):
     """
@@ -22,30 +44,34 @@ class Strategy(abc.ABC):
     Parameters
     ----------
     k :
-        Maximum number of nodes to retrieve during traversal.
+        Maximum number of nodes to select and return during traversal.
     start_k :
         Number of documents to fetch via similarity for starting the traversal.
         Added to any initial roots provided to the traversal.
     adjacent_k :
         Number of documents to fetch for each outgoing edge.
+    traverse_k :
+        Maximum number of nodes to traverse outgoing edges from before returning.
     max_depth :
         Maximum traversal depth. If `None`, there is no limit.
     """
 
     k: int = 5
Review comment: potentially change to ...
     start_k: int = 4
     adjacent_k: int = 10
+    traverse_k: int = 4  # max_traverse?
Review comment: This isn't currently used anywhere.
Review comment: Yes. I believe this is a new bound which the traversal and/or strategy would need to respect and terminate traversal once the ...
Review comment: This is a limit on the total number of nodes traversed.
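The last comment reads `traverse_k` as a cap on the total number of nodes traversed. A minimal sketch of enforcing such a cap, assuming the traversal keeps a running count; `limit_batch` and `traversed` are hypothetical names, and nothing here is taken from the PR.

```python
from itertools import islice


def limit_batch(
    batch: dict[str, object], *, traversed: int, traverse_k: int
) -> dict[str, object]:
    """Trim `batch` so the running total never exceeds `traverse_k`.

    `traversed` is the number of nodes already traversed in earlier iterations.
    Insertion order of `batch` decides which nodes are kept.
    """
    remaining = max(traverse_k - traversed, 0)
    return dict(islice(batch.items(), remaining))
```

Under this reading, the traversal would stop once `limit_batch` returns an empty dict, either because nothing is queued or because the budget is spent.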
     max_depth: int | None = None
 
     _query_embedding: list[float] = dataclasses.field(default_factory=list)
 
     @abc.abstractmethod
-    def discover_nodes(self, nodes: dict[str, Node]) -> None:
+    def iteration(self, *, nodes: dict[str, Node], tracker: NodeTracker) -> None:
         """
-        Add discovered nodes to the strategy.
+        Process the newly discovered nodes on each iteration.
Review comment: It may be worth elaborating on this -- the "newly discovered" part seems important. For instance, it means that a strategy that wants to visit nodes in a later iteration (such as MMR) needs to remember them itself and select/traverse them later.
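A minimal sketch of the bookkeeping this comment describes, assuming `iteration` only ever receives nodes it has not seen before. This is a simple best-first strategy, not MMR. `Node` is a stand-in with a hypothetical `score` field, `NodeTracker` is copied from the diff so the block runs on its own, and `BestFirstStrategy` is an illustrative name that does not subclass the real `Strategy`.

```python
import dataclasses


@dataclasses.dataclass
class Node:
    """Hypothetical stand-in for graph_retriever.types.Node."""

    id: str
    score: float  # assumed relevance score, not part of the PR


class NodeTracker:
    """Copy of the tracker from the diff, so the sketch is self-contained."""

    def __init__(self) -> None:
        self.visited_ids: set[str] = set()
        self.to_traverse: dict[str, Node] = {}
        self.selected: dict[str, Node] = {}

    def select(self, nodes: dict[str, Node]) -> None:
        self.selected.update(nodes)

    def traverse(self, nodes: dict[str, Node]) -> None:
        self.to_traverse.update(nodes)

    def select_and_traverse(self, nodes: dict[str, Node]) -> None:
        self.select(nodes)
        self.traverse(nodes)


@dataclasses.dataclass
class BestFirstStrategy:
    """Visit the single best-scoring remembered candidate on each iteration."""

    _candidates: dict[str, Node] = dataclasses.field(default_factory=dict)

    def iteration(self, *, nodes: dict[str, Node], tracker: NodeTracker) -> None:
        # `nodes` holds only the newly discovered nodes, so remember them here;
        # candidates that are not picked now stay available for later iterations.
        self._candidates.update(nodes)
        if not self._candidates:
            return  # nothing left to select or traverse
        best_id = max(self._candidates, key=lambda nid: self._candidates[nid].score)
        best = self._candidates.pop(best_id)
        tracker.select_and_traverse({best_id: best})
```

A strategy that selected only from the `nodes` argument would lose every candidate it passed over, which is why the pool lives on the strategy rather than on the tracker in this sketch.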
 
-        This method updates the strategy's state with nodes discovered during
-        the traversal process.
+        This method should call `traverse` and/or `select` as appropriate
+        to update the nodes that need to be traversed in this iteration or
+        selected at the end of the retrieval, respectively.
 
         Parameters
         ----------
@@ -54,43 +80,20 @@ def discover_nodes(self, nodes: dict[str, Node]) -> None:
         """
         ...
 
-    @abc.abstractmethod
-    def select_nodes(self, *, limit: int) -> Iterable[Node]:
-        """
-        Select discovered nodes to visit in the next iteration.
-
-        This method determines which nodes will be traversed next. If it returns
-        an empty list, traversal ends even if fewer than `k` nodes have been selected.
-
-        Parameters
-        ----------
-        limit :
-            Maximum number of nodes to select.
-
-        Returns
-        -------
-        :
-            Selected nodes for the next iteration. Traversal ends if this is empty.
-        """
-        ...
-
-    def finalize_nodes(self, nodes: Iterable[Node]) -> Iterable[Node]:
+    def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
         """
         Finalize the selected nodes.
 
         This method is called before returning the final set of nodes.
 
-        Parameters
-        ----------
-        nodes :
-            Nodes selected for finalization.
-
         Returns
         -------
         :
             Finalized nodes.
         """
-        return nodes
+        # Take the first `self.k` selected items.
+        # Strategies may override finalize to perform reranking if needed.
+        return list(selected)[: self.k]
 
     @staticmethod
     def build(
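The new `finalize_nodes` default truncates to `self.k`, and its comment notes that strategies may override it to rerank. A minimal sketch of such an override, assuming nodes carry a score; `Node.score`, `BaseStrategy`, and `RerankingStrategy` are hypothetical stand-ins rather than classes from the PR.

```python
from collections.abc import Iterable
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for graph_retriever.types.Node."""

    id: str
    score: float  # assumed relevance score, not part of the PR


@dataclass
class BaseStrategy:
    """Stand-in carrying only the pieces of Strategy that finalize_nodes needs."""

    k: int = 5

    def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
        # Same default as in the diff: keep the first `k` selected nodes.
        return list(selected)[: self.k]


@dataclass
class RerankingStrategy(BaseStrategy):
    """Rerank by score before applying the default truncation."""

    def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
        ranked = sorted(selected, key=lambda node: node.score, reverse=True)
        return super().finalize_nodes(ranked)
```

Delegating to `super().finalize_nodes` keeps the `k` limit in one place, so an override only has to decide the order.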
Review comment: Maybe add a property to compute the remaining unselected.