updated the strategy design to separate traversal from node selection #152

Open · wants to merge 4 commits into base: main · Changes from 3 commits
67 changes: 35 additions & 32 deletions packages/graph-retriever/src/graph_retriever/strategies/base.py
@@ -10,6 +10,28 @@
from graph_retriever.types import Node


class NodeTracker:
Comment (Collaborator Author): maybe add a property to compute the remaining unselected.

"""Helper class for tracking traversal progress."""

def __init__(self) -> None:
self.visited_ids: set[str] = set()
Comment (Collaborator): Should visited_ids be removed for now? Or should more tracking logic move into here?

self.to_traverse: dict[str, Node] = {}
self.selected: dict[str, Node] = {}

def select(self, nodes: dict[str, Node]) -> None:
"""Select nodes to be included in the result set."""
self.selected.update(nodes)
Comment (Collaborator): Should select return the number of new nodes (or the number of total nodes)? I guess most use cases would be better off with len(tracker.selected). Although, we could have a @property def num_selected(self) -> int?

def traverse(self, nodes: dict[str, Node]) -> None:
"""Select nodes to be included in the next traversal."""
self.to_traverse.update(nodes)
Comment (Collaborator): Note: if this doesn't track visited, then we potentially have cases where already-visited nodes are added (increasing the traverse set) but not actually traversed again. It may make the behavior easier to reason about if visited nodes are tracked here.

Comment (Collaborator Author): have these return the number of nodes that will actually be traversed or selected.

Comment (Collaborator Author): empty the to_traverse queue at the end of an iteration.

def select_and_traverse(self, nodes: dict[str, Node]) -> None:
"""Select nodes to be included in the result set and the next traversal."""
self.select(nodes=nodes)
Comment (Collaborator): nit: doesn't need to be named -- just self.select(nodes).
self.traverse(nodes=nodes)

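To make the review suggestions above concrete, here is a sketch (not part of the PR) of a tracker where `select` returns the number of newly added nodes and a `num_selected` property is available. The minimal `Node` stand-in and the class name are illustrative only:

```python
import dataclasses


@dataclasses.dataclass
class Node:
    """Minimal stand-in for graph_retriever.types.Node, for illustration."""

    id: str


class NodeTrackerSketch:
    """Sketch of the reviewers' suggestions for NodeTracker."""

    def __init__(self) -> None:
        self.visited_ids: set[str] = set()
        self.to_traverse: dict[str, Node] = {}
        self.selected: dict[str, Node] = {}

    @property
    def num_selected(self) -> int:
        # The `num_selected` property floated in the review above.
        return len(self.selected)

    def select(self, nodes: dict[str, Node]) -> int:
        # Return the number of *new* nodes, per the review suggestion.
        before = len(self.selected)
        self.selected.update(nodes)
        return len(self.selected) - before
```

With this shape, callers can distinguish "nothing new was selected" (return value 0) from the running total (`num_selected`).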

@dataclasses.dataclass(kw_only=True)
class Strategy(abc.ABC):
"""
@@ -22,30 +44,34 @@ class Strategy(abc.ABC):
Parameters
----------
k :
Maximum number of nodes to retrieve during traversal.
Maximum number of nodes to select and return during traversal.
start_k :
Number of documents to fetch via similarity for starting the traversal.
Added to any initial roots provided to the traversal.
adjacent_k :
Number of documents to fetch for each outgoing edge.
traverse_k :
Maximum number of nodes to traverse outgoing edges from before returning.
max_depth :
Maximum traversal depth. If `None`, there is no limit.
"""

k: int = 5
Comment (Collaborator Author): potentially change to select_k

start_k: int = 4
adjacent_k: int = 10
traverse_k: int = 4 # max_traverse?
Comment (Collaborator Author): this isn't currently used anywhere.

Comment (Collaborator): Yes. I believe this is a new bound which the traversal and/or strategy would need to respect, terminating traversal once len(tracker.visited_ids) > traverse_k or something.

Comment (Collaborator Author): this is a limit on the total number of nodes traversed.
max_depth: int | None = None

_query_embedding: list[float] = dataclasses.field(default_factory=list)

@abc.abstractmethod
def discover_nodes(self, nodes: dict[str, Node]) -> None:
def iteration(self, *, nodes: dict[str, Node], tracker: NodeTracker) -> None:
"""
Add discovered nodes to the strategy.
Process the newly discovered nodes on each iteration.
Comment (Collaborator): May be worth elaborating on this -- the newly discovered seems important. For instance, this means that if a strategy wants to potentially visit them in a future iteration (such as MMR), it needs to remember these nodes and select/traverse them later.

This method updates the strategy's state with nodes discovered during
the traversal process.
This method should call `traverse` and/or `select` as appropriate
to update the nodes that need to be traversed in this iteration or
selected at the end of the retrieval, respectively.

Parameters
----------
@@ -54,43 +80,20 @@ def discover_nodes(self, nodes: dict[str, Node]) -> None:
"""
...
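To illustrate the new single `iteration` hook, here is a hypothetical strategy (not part of the PR) that selects every discovered node but only keeps traversing from nodes above a score threshold. The `Node`, `NodeTracker`, and `ThresholdStrategy` definitions are simplified stand-ins for illustration:

```python
import dataclasses


@dataclasses.dataclass
class Node:
    """Minimal stand-in for graph_retriever.types.Node."""

    id: str
    score: float = 0.0


class NodeTracker:
    """Minimal stand-in mirroring the tracker interface above."""

    def __init__(self) -> None:
        self.to_traverse: dict[str, Node] = {}
        self.selected: dict[str, Node] = {}

    def select(self, nodes: dict[str, Node]) -> None:
        self.selected.update(nodes)

    def traverse(self, nodes: dict[str, Node]) -> None:
        self.to_traverse.update(nodes)


class ThresholdStrategy:
    """Hypothetical strategy: select everything, traverse only high scorers."""

    def __init__(self, min_score: float) -> None:
        self.min_score = min_score

    def iteration(self, *, nodes: dict[str, Node], tracker: NodeTracker) -> None:
        # Everything discovered goes into the result set...
        tracker.select(nodes)
        # ...but only promising nodes seed the next traversal step.
        tracker.traverse(
            {nid: n for nid, n in nodes.items() if n.score >= self.min_score}
        )
```

This shows the division of labor the redesign aims for: the traversal loop owns *how* to walk the graph, while the strategy's `iteration` only decides *which* of the newly discovered nodes to select and/or traverse.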

@abc.abstractmethod
def select_nodes(self, *, limit: int) -> Iterable[Node]:
"""
Select discovered nodes to visit in the next iteration.

This method determines which nodes will be traversed next. If it returns
an empty list, traversal ends even if fewer than `k` nodes have been selected.

Parameters
----------
limit :
Maximum number of nodes to select.

Returns
-------
:
Selected nodes for the next iteration. Traversal ends if this is empty.
"""
...

def finalize_nodes(self, nodes: Iterable[Node]) -> Iterable[Node]:
def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
"""
Finalize the selected nodes.

This method is called before returning the final set of nodes.

Parameters
----------
nodes :
Nodes selected for finalization.

Returns
-------
:
Finalized nodes.
"""
return nodes
# Take the first `self.k` selected items.
# Strategies may override finalize to perform reranking if needed.
return list(selected)[: self.k]
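As the comment above notes, strategies may override `finalize_nodes` to rerank before truncation. A sketch of such an override (standalone; the `score` attribute on the stand-in `Node` is an assumption for illustration):

```python
import dataclasses
from collections.abc import Iterable


@dataclasses.dataclass
class Node:
    """Minimal stand-in; the `score` attribute is assumed for illustration."""

    id: str
    score: float


def rerank_finalize(selected: Iterable[Node], k: int) -> list[Node]:
    """Sketch of a finalize override: rerank by score, then keep the top k."""
    return sorted(selected, key=lambda n: n.score, reverse=True)[:k]
```

Unlike the default, which keeps the first `k` selected items in discovery order, this variant keeps the `k` highest-scoring items regardless of when they were selected.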

@staticmethod
def build(
15 changes: 3 additions & 12 deletions packages/graph-retriever/src/graph_retriever/strategies/eager.py
@@ -1,11 +1,10 @@
"""Provide eager (breadth-first) traversal strategy."""

import dataclasses
from collections.abc import Iterable

from typing_extensions import override

from graph_retriever.strategies.base import Strategy
from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node


@@ -32,14 +31,6 @@ class Eager(Strategy):
Maximum traversal depth. If `None`, there is no limit.
"""

_nodes: list[Node] = dataclasses.field(default_factory=list)

@override
def discover_nodes(self, nodes: dict[str, Node]) -> None:
self._nodes.extend(nodes.values())

@override
def select_nodes(self, *, limit: int) -> Iterable[Node]:
nodes = self._nodes[:limit]
self._nodes = []
return nodes
def iteration(self, nodes: dict[str, Node], tracker: NodeTracker) -> None:
tracker.select_and_traverse(nodes)
105 changes: 48 additions & 57 deletions packages/graph-retriever/src/graph_retriever/strategies/mmr.py
@@ -8,7 +8,7 @@
from numpy.typing import NDArray
from typing_extensions import override

from graph_retriever.strategies.base import Strategy
from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node
from graph_retriever.utils.math import cosine_similarity

@@ -203,8 +203,7 @@ def _pop_candidate(

return candidate, embedding

@override
def select_nodes(self, *, limit: int) -> Iterable[Node]:
def _next(self) -> dict[str, Node]:
"""
Select and pop the best item being considered.

@@ -214,10 +213,8 @@ def select_nodes(self, *, limit: int) -> Iterable[Node]:
-------
The best node as a single-entry dict keyed by its ID, or an empty dict.
"""
if limit == 0:
return []
if self._best_id is None or self._best_score < self.min_mmr_score:
return []
return {}

# Get the selection and remove from candidates.
selected_id = self._best_id
@@ -250,61 +247,55 @@ def select_nodes(self, *, limit: int) -> Iterable[Node]:
self._best_score = candidate.score
self._best_id = candidate.node.id

return [selected_node]
return {selected_node.id: selected_node}

@override
def discover_nodes(self, nodes: dict[str, Node]) -> None:
def iteration(self, nodes: dict[str, Node], tracker: NodeTracker) -> None:
"""Add candidates to the consideration set."""
# Determine the keys to actually include.
# These are the candidates that aren't already selected
# or under consideration.

include_ids_set = set(nodes.keys())
include_ids_set.difference_update(self._selected_ids)
include_ids_set.difference_update(self._candidate_id_to_index.keys())
include_ids = list(include_ids_set)

# Now, build up a matrix of the remaining candidate embeddings.
# And add them to the
new_embeddings: NDArray[np.float32] = np.ndarray(
(
len(include_ids),
self._dimensions,
if len(nodes) > 0:
# Build up a matrix of the remaining candidate embeddings.
# And add them to the candidate set
new_embeddings: NDArray[np.float32] = np.ndarray(
(
len(nodes),
self._dimensions,
)
)
)
offset = self._candidate_embeddings.shape[0]
for index, candidate_id in enumerate(include_ids):
self._candidate_id_to_index[candidate_id] = offset + index
new_embeddings[index] = nodes[candidate_id].embedding

# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self._nd_query_embedding)

# Compute the distance metrics of all of pairs in the selected set with
# the new candidates.
redundancy = cosine_similarity(
new_embeddings, self._already_selected_embeddings()
)
for index, candidate_id in enumerate(include_ids):
max_redundancy = 0.0
if redundancy.shape[0] > 0:
max_redundancy = redundancy[index].max()
candidate = _MmrCandidate(
node=nodes[candidate_id],
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self._lambda_mult_complement * max_redundancy,
offset = self._candidate_embeddings.shape[0]
for index, candidate_id in enumerate(nodes.keys()):
self._candidate_id_to_index[candidate_id] = offset + index
new_embeddings[index] = nodes[candidate_id].embedding

# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self._nd_query_embedding)

# Compute the distance metrics of all of pairs in the selected set with
# the new candidates.
redundancy = cosine_similarity(
new_embeddings, self._already_selected_embeddings()
)
self._candidates.append(candidate)

if candidate.score >= self._best_score:
self._best_score = candidate.score
self._best_id = candidate.node.id
for index, candidate_id in enumerate(nodes.keys()):
max_redundancy = 0.0
if redundancy.shape[0] > 0:
max_redundancy = redundancy[index].max()
candidate = _MmrCandidate(
node=nodes[candidate_id],
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self._lambda_mult_complement * max_redundancy,
)
self._candidates.append(candidate)

if candidate.score >= self._best_score:
self._best_score = candidate.score
self._best_id = candidate.node.id

# Add the new embeddings to the candidate set.
self._candidate_embeddings = np.vstack(
(
self._candidate_embeddings,
new_embeddings,
# Add the new embeddings to the candidate set.
self._candidate_embeddings = np.vstack(
(
self._candidate_embeddings,
new_embeddings,
)
)
)

tracker.select_and_traverse(self._next())
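For reference, the score that `_MmrCandidate` encodes via its `weighted_similarity` and `weighted_redundancy` fields is the standard MMR combination of relevance and redundancy. A standalone sketch of that arithmetic (function name and defaults are illustrative):

```python
def mmr_score(
    similarity: float, max_redundancy: float, lambda_mult: float = 0.5
) -> float:
    """Standard MMR: reward query similarity, penalize redundancy with
    already-selected nodes, balanced by lambda_mult in [0, 1]."""
    return lambda_mult * similarity - (1.0 - lambda_mult) * max_redundancy
```

With `lambda_mult = 1.0` the score reduces to pure query similarity; lower values trade relevance for diversity, which is why the loop above recomputes each candidate's `weighted_redundancy` as nodes are selected.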
26 changes: 11 additions & 15 deletions packages/graph-retriever/src/graph_retriever/strategies/scored.py
@@ -1,10 +1,10 @@
import dataclasses
import heapq
from collections.abc import Callable, Iterable
from collections.abc import Callable

from typing_extensions import override

from graph_retriever.strategies.base import Strategy
from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node


@@ -13,33 +13,29 @@ def __init__(self, score: float, node: Node) -> None:
self.score = score
self.node = node

def __lt__(self, other) -> bool:
def __lt__(self, other: "_ScoredNode") -> bool:
return other.score < self.score


@dataclasses.dataclass
class Scored(Strategy):
"""Strategy selecing nodes using a scoring function."""
"""Strategy selecting nodes using a scoring function."""

scorer: Callable[[Node], float]
_nodes: list[_ScoredNode] = dataclasses.field(default_factory=list)

per_iteration_limit: int | None = None
per_iteration_limit: int = 2

@override
def discover_nodes(self, nodes: dict[str, Node]) -> None:
def iteration(self, nodes: dict[str, Node], tracker: NodeTracker) -> None:
for node in nodes.values():
heapq.heappush(self._nodes, _ScoredNode(self.scorer(node), node))

@override
def select_nodes(self, *, limit: int) -> Iterable[Node]:
if self.per_iteration_limit and self.per_iteration_limit < limit:
limit = self.per_iteration_limit

selected = []
for _x in range(limit):
selected = {}
for _x in range(self.per_iteration_limit):
if not self._nodes:
break

selected.append(heapq.heappop(self._nodes).node)
return selected
node = heapq.heappop(self._nodes).node
selected[node.id] = node
tracker.select_and_traverse(selected)
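The heap mechanics behind `Scored` are worth spelling out: `_ScoredNode.__lt__` reverses the comparison, so Python's min-heap pops the *highest*-scoring node first, and each iteration takes at most `per_iteration_limit` nodes. A standalone sketch with minimal stand-ins (the `Node` type and inline data are illustrative):

```python
import dataclasses
import heapq


@dataclasses.dataclass
class Node:
    """Minimal stand-in for graph_retriever.types.Node."""

    id: str


class _ScoredNode:
    def __init__(self, score: float, node: Node) -> None:
        self.score = score
        self.node = node

    def __lt__(self, other: "_ScoredNode") -> bool:
        # Reversed comparison: heapq's min-heap then pops the highest score.
        return other.score < self.score


heap: list[_ScoredNode] = []
for node_id, score in [("a", 0.2), ("b", 0.9), ("c", 0.5)]:
    heapq.heappush(heap, _ScoredNode(score, Node(node_id)))

per_iteration_limit = 2
selected: dict[str, Node] = {}
for _ in range(per_iteration_limit):
    if not heap:
        break
    node = heapq.heappop(heap).node
    selected[node.id] = node
# selected now holds the two highest-scoring nodes: "b" and "c"
```

Lower-scoring nodes stay on the heap, so they remain candidates for selection in a later iteration rather than being discarded.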