Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updated the strategy design to separate traversal from node selection #152

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions packages/graph-retriever/src/graph_retriever/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,36 @@
class NodeTracker:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a property that computes the number of nodes remaining to be selected.

"""Helper class for tracking traversal progress."""

def __init__(self) -> None:
self.visited_ids: set[str] = set()
def __init__(self, select_k: int, max_depth: int| None) -> None:
self._select_k: int = select_k
self._max_depth: int | None = max_depth
self._visited_nodes: set[int] = set()
self.to_traverse: dict[str, Node] = {}
self.selected: dict[str, Node] = {}

@property
def remaining(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe name this num_remaining, or something similar, to emphasize that it's a number? It might also help to mention selection in the name (e.g. num_unselected?).

"""The remaining number of nodes to be selected"""
return max(self._select_k - len(self.selected), 0)

def select(self, nodes: dict[str, Node]) -> None:
"""Select nodes to be included in the result set."""
self.selected.update(nodes)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should select return the number of new nodes (or the number of total nodes)?

I guess most use cases would be better off with len(tracker.selected). Although, we could have a @property def num_selected(self) -> int?


def traverse(self, nodes: dict[str, Node]) -> None:
def traverse(self, nodes: dict[str, Node]) -> int:
"""Select nodes to be included in the next traversal."""
self.to_traverse.update(nodes)
for id, node in nodes.items():
if id in self._visited_nodes:
continue
if self._max_depth is None or node.depth <= self._max_depth:
continue
self.to_traverse[id] = node
return len(self.to_traverse)

def select_and_traverse(self, nodes: dict[str, Node]) -> None:
"""Select nodes to be included in the result set and the next traversal."""
self.select(nodes=nodes)
self.traverse(nodes=nodes)
self.select(nodes)
self.traverse(nodes)


@dataclasses.dataclass(kw_only=True)
Expand All @@ -43,23 +56,24 @@ class Strategy(abc.ABC):

Parameters
----------
k :
select_k :
Maximum number of nodes to select and return during traversal.
start_k :
Number of documents to fetch via similarity for starting the traversal.
Number of nodes to fetch via similarity for starting the traversal.
Added to any initial roots provided to the traversal.
adjacent_k :
Number of documents to fetch for each outgoing edge.
traverse_k :
Number of nodes to fetch for each outgoing edge.
max_traverse :
Maximum number of nodes to traverse outgoing edges from before returning.
If `None`, there is no limit.
max_depth :
Maximum traversal depth. If `None`, there is no limit.
"""

k: int = 5
select_k: int = 5
start_k: int = 4
adjacent_k: int = 10
traverse_k: int = 4 # max_traverse?
max_traverse: int | None = None
max_depth: int | None = None

_query_embedding: list[float] = dataclasses.field(default_factory=list)
Expand Down Expand Up @@ -93,7 +107,7 @@ def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
"""
# Take the first `self.k` selected items.
# Strategies may override finalize to perform reranking if needed.
return list(selected)[: self.k]
return list(selected)[: self.select_k]

@staticmethod
def build(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Mmr(Strategy):

Parameters
----------
k :
select_k :
Maximum number of nodes to retrieve during traversal.
start_k :
Number of documents to fetch via similarity for starting the traversal.
Expand Down Expand Up @@ -111,7 +111,7 @@ def _selected_embeddings(self) -> NDArray[np.float32]:
NDArray[np.float32]
(N, dim) ndarray with a row for each selected node.
"""
return np.ndarray((self.k, self._dimensions), dtype=np.float32)
return np.ndarray((self.select_k, self._dimensions), dtype=np.float32)

@cached_property
def _candidate_embeddings(self) -> NDArray[np.float32]:
Expand Down
134 changes: 30 additions & 104 deletions packages/graph-retriever/src/graph_retriever/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,7 @@ def __init__(
self._used = False
self._visited_edges: set[Edge] = set()
self._edge_depths: dict[Edge, int] = {}
self._existing_nodes: set[str] = set()
self._node_tracker: NodeTracker = NodeTracker()
self._node_tracker: NodeTracker = NodeTracker(select_k=strategy.select_k, max_depth=strategy.max_depth)

def _check_first_use(self):
assert not self._used, "Traversals cannot be re-used."
Expand All @@ -178,22 +177,19 @@ def traverse(self) -> list[Node]:
self._check_first_use()

# Retrieve initial candidates.
initial_content = self._fetch_initial_candidates()
content = self._fetch_initial_candidates()
if self.initial_root_ids:
initial_content.extend(self.store.get(self.initial_root_ids))

self.iteration(initial_content, depth=0)
content.extend(self.store.get(self.initial_root_ids))

while True:
next_outgoing_edges = self.select_next_edges()
if next_outgoing_edges is None:
nodes = [self._content_to_node(c, depth=0) for c in content]
self.strategy.iteration(nodes={n.id: n for n in nodes}, tracker=self._node_tracker)

if self._node_tracker.remaining == 0 or len(self._node_tracker.to_traverse) == 0:
break
elif next_outgoing_edges:
# Find the (new) document with incoming edges from those edges.
adjacent_content = self._fetch_adjacent(next_outgoing_edges)
self.iteration(adjacent_content)
else:
self.iteration({})

next_outgoing_edges = self.select_next_edges(self._node_tracker.to_traverse)
content = self._fetch_adjacent(next_outgoing_edges)

return self.finish()

Expand All @@ -212,23 +208,19 @@ async def atraverse(self) -> list[Node]:
self._check_first_use()

# Retrieve initial candidates.
initial_content = await self._afetch_initial_candidates()

content = await self._afetch_initial_candidates()
if self.initial_root_ids:
initial_content.extend(await self.store.aget(self.initial_root_ids))

self.iteration(initial_content, depth=0)
content.extend(await self.store.aget(self.initial_root_ids))

while True:
next_outgoing_edges = self.select_next_edges()
if next_outgoing_edges is None:
nodes = [self._content_to_node(c, depth=0) for c in content]
self.strategy.iteration(nodes={n.id: n for n in nodes}, tracker=self._node_tracker)

if self._node_tracker.remaining == 0 or len(self._node_tracker.to_traverse) == 0:
break
elif next_outgoing_edges:
# Find the (new) content with incoming edges from those edges.
adjacent_content = await self._afetch_adjacent(next_outgoing_edges)
self.iteration(adjacent_content)
else:
self.iteration({})

next_outgoing_edges = self.select_next_edges(self._node_tracker.to_traverse)
content = await self._afetch_adjacent(next_outgoing_edges)

return self.finish()

Expand Down Expand Up @@ -310,15 +302,14 @@ async def _afetch_adjacent(self, edges: set[Edge]) -> Iterable[Content]:
**self.store_kwargs,
)

def _content_to_new_node(
def _content_to_node(
self, content: Content, *, depth: int | None = None
) -> Node | None:
"""
Convert a content into a new node for the traversal.
Converts a content object into a node for traversal.

This method checks whether the document has already been processed. If not,
it creates a new `Node` instance, associates it with the document's metadata,
and calculates its depth based on the incoming edges.
This method creates a new `Node` instance, associates it with the document's
metadata, and calculates its depth based on the incoming edges.

Parameters
----------
Expand All @@ -334,8 +325,6 @@ def _content_to_new_node(
The newly created node, or None if the document has already been
processed.
"""
if content.id in self._existing_nodes:
return None

# Determine incoming/outgoing edges.
edges = self.edge_function(content)
Expand All @@ -351,7 +340,7 @@ def _content_to_new_node(
default=0,
)

node = Node(
return Node(
id=content.id,
content=content.content,
depth=depth,
Expand All @@ -361,52 +350,18 @@ def _content_to_new_node(
outgoing_edges=edges.outgoing,
)

self._existing_nodes.add(node.id)

return node

def iteration(
self, contents: Iterable[Content], *, depth: int | None = None
) -> None:
def select_next_edges(self, nodes: Iterable[Node]) -> set[Edge]:
"""
Convert a bunch of content into nodes and send them to strategy iteration.
Find the unvisited outgoing edges from the set of new nodes to traverse.

This method records the depth of new nodes, filters them based on the
strategy's maximum depth and sends the filtered set to the next strategy
iteration.

Parameters
----------
contents :
The contents to add.
depth :
The depth to assign to the nodes. If None, the depth is inferred
based on the incoming edges.
"""
# Record the depth of new nodes.
nodes = {
node.id: node
for c in contents
if (node := self._content_to_new_node(c, depth=depth)) is not None
if (
self.strategy.max_depth is None or node.depth <= self.strategy.max_depth
)
}
self.strategy.iteration(nodes=nodes, tracker=self._node_tracker)

def visit_nodes(self, nodes: Iterable[Node]) -> set[Edge]:
"""
Mark nodes as visited and return their new outgoing edges.

This method updates the traversal state by marking the provided nodes as visited
and recording their outgoing edges. Outgoing edges that have not been visited
before are identified and added to the set of edges to explore in subsequent
traversal steps.
This method updates the traversal state by recording the outgoing edges of the
provided nodes. Outgoing edges that have not been visited before are identified
and added to the set of edges to explore in subsequent traversal steps.

Parameters
----------
nodes :
The nodes to mark as visited.
The new nodes to traverse

Returns
-------
Expand Down Expand Up @@ -434,37 +389,8 @@ def visit_nodes(self, nodes: Iterable[Node]) -> set[Edge]:

new_outgoing_edge_set = set(new_outgoing_edges.keys())
self._visited_edges.update(new_outgoing_edge_set)
self._node_tracker.visited_ids.update([n.id for n in nodes])
return new_outgoing_edge_set

def select_next_edges(self) -> set[Edge] | None:
"""
Select the next set of edges to explore.

This method uses the node tracker to select the next batch of nodes
and identifies new outgoing edges for exploration.

Returns
-------
:
The set of new edges to explore, or None if the traversal is
complete.
"""
remaining = self.strategy.k - len(self._node_tracker.selected)
if remaining <= 0:
return None

next_nodes = self._node_tracker.to_traverse.values()
if not next_nodes:
return None

filtered_nodes = [
n for n in next_nodes if n.id not in self._node_tracker.visited_ids
]
if len(filtered_nodes) == 0:
return None

return self.visit_nodes(filtered_nodes)

def finish(self) -> list[Node]:
"""
Expand Down
32 changes: 16 additions & 16 deletions packages/graph-retriever/tests/strategies/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@


def test_build_strategy_base():
base_strategy = Eager(k=6, start_k=5, adjacent_k=9, max_depth=2)
base_strategy = Eager(select_k=6, start_k=5, adjacent_k=9, max_depth=2)

# base strategy with no changes
strategy = Strategy.build(base_strategy=base_strategy)
assert strategy == base_strategy

# base strategy with changed k
strategy = Strategy.build(base_strategy=base_strategy, k=7)
assert strategy == Eager(k=7, start_k=5, adjacent_k=9, max_depth=2)
strategy = Strategy.build(base_strategy=base_strategy, select_k=7)
assert strategy == Eager(select_k=7, start_k=5, adjacent_k=9, max_depth=2)

# base strategy with invalid kwarg
with pytest.raises(
Expand All @@ -28,20 +28,20 @@ def test_build_strategy_base():


def test_build_strategy_base_override():
base_strategy = Eager(k=6, start_k=5, adjacent_k=9, max_depth=2)
override_strategy = Eager(k=7, start_k=4, adjacent_k=8, max_depth=3)
base_strategy = Eager(select_k=6, start_k=5, adjacent_k=9, max_depth=2)
override_strategy = Eager(select_k=7, start_k=4, adjacent_k=8, max_depth=3)

# override base strategy
strategy = Strategy.build(
base_strategy=base_strategy, strategy=override_strategy, k=4
base_strategy=base_strategy, strategy=override_strategy, select_k=4
)
assert strategy == dataclasses.replace(override_strategy, k=4)
assert strategy == dataclasses.replace(override_strategy, select_k=4)

# override base strategy and change params
strategy = Strategy.build(
base_strategy=base_strategy, strategy=override_strategy, k=3, adjacent_k=7
base_strategy=base_strategy, strategy=override_strategy, select_k=3, adjacent_k=7
)
assert strategy == Eager(k=3, start_k=4, adjacent_k=7, max_depth=3)
assert strategy == Eager(select_k=3, start_k=4, adjacent_k=7, max_depth=3)

# override base strategy and invalid kwarg
with pytest.raises(
Expand All @@ -50,7 +50,7 @@ def test_build_strategy_base_override():
strategy = Strategy.build(
base_strategy=base_strategy,
strategy=override_strategy,
k=4,
select_k=4,
invalid_kwarg=4,
)

Expand All @@ -63,8 +63,8 @@ def test_build_strategy_base_override():


def test_build_strategy_base_override_mmr():
base_strategy = Eager(k=6, start_k=5, adjacent_k=9, max_depth=2)
override_strategy = Mmr(k=7, start_k=4, adjacent_k=8, max_depth=3, lambda_mult=0.3)
base_strategy = Eager(select_k=6, start_k=5, adjacent_k=9, max_depth=2)
override_strategy = Mmr(select_k=7, start_k=4, adjacent_k=8, max_depth=3, lambda_mult=0.3)

# override base strategy with mmr kwarg
with pytest.raises(
Expand All @@ -75,15 +75,15 @@ def test_build_strategy_base_override_mmr():

# override base strategy with mmr strategy
strategy = Strategy.build(
base_strategy=base_strategy, strategy=override_strategy, k=4
base_strategy=base_strategy, strategy=override_strategy, select_k=4
)
assert strategy == dataclasses.replace(override_strategy, k=4)
assert strategy == dataclasses.replace(override_strategy, select_k=4)

# override base strategy with mmr strategy and mmr arg
strategy = Strategy.build(
base_strategy=base_strategy, strategy=override_strategy, k=4, lambda_mult=0.2
base_strategy=base_strategy, strategy=override_strategy, select_k=4, lambda_mult=0.2
)
assert strategy == Mmr(k=4, start_k=4, adjacent_k=8, max_depth=3, lambda_mult=0.2)
assert strategy == Mmr(select_k=4, start_k=4, adjacent_k=8, max_depth=3, lambda_mult=0.2)

# start with override strategy, change to base, try to set mmr arg
with pytest.raises(
Expand Down
Loading
Loading