Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updated the strategy design to separate traversal from node selection #152

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 53 additions & 36 deletions packages/graph-retriever/src/graph_retriever/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,41 @@
from graph_retriever.types import Node


class NodeTracker:
    """Helper class for tracking traversal progress.

    Tracks the nodes selected for the final result set and the nodes
    queued for the next traversal iteration.

    Parameters
    ----------
    select_k :
        Maximum number of nodes to select during the traversal.
    max_depth :
        Maximum traversal depth. If `None`, there is no limit.
    """

    def __init__(self, select_k: int, max_depth: int | None) -> None:
        self._select_k: int = select_k
        self._max_depth: int | None = max_depth
        # IDs of nodes already queued for traversal, so re-discovered
        # nodes are not traversed twice. (Node IDs are strings, so this
        # must be `set[str]`, not `set[int]`.)
        self._visited_nodes: set[str] = set()
        # Nodes whose outgoing edges should be followed next iteration.
        self.to_traverse: dict[str, Node] = {}
        # Nodes selected for inclusion in the final result set.
        self.selected: dict[str, Node] = {}

    @property
    def remaining(self) -> int:
        """The remaining number of nodes to be selected."""
        return max(self._select_k - len(self.selected), 0)

    def select(self, nodes: dict[str, Node]) -> None:
        """Select nodes to be included in the result set."""
        self.selected.update(nodes)

    def traverse(self, nodes: dict[str, Node]) -> int:
        """Select nodes to be included in the next traversal.

        Nodes that were already queued for traversal, or that are deeper
        than `max_depth`, are skipped.

        Parameters
        ----------
        nodes :
            Candidate nodes, keyed by ID.

        Returns
        -------
        :
            The total number of nodes queued for the next traversal.
        """
        for id, node in nodes.items():
            if id in self._visited_nodes:
                continue
            # Fix: the depth test was inverted -- it previously skipped
            # nodes *within* the depth limit and traversed nodes beyond it.
            # Skip only nodes deeper than the (finite) depth limit.
            if self._max_depth is not None and node.depth > self._max_depth:
                continue
            self.to_traverse[id] = node
            # Record the node as visited so a later re-discovery of the
            # same ID is not traversed again. (Previously nothing was ever
            # added to `_visited_nodes`, making the check above dead code.)
            self._visited_nodes.add(id)
        return len(self.to_traverse)

    def select_and_traverse(self, nodes: dict[str, Node]) -> None:
        """Select nodes to be included in the result set and the next traversal."""
        self.select(nodes)
        self.traverse(nodes)


@dataclasses.dataclass(kw_only=True)
class Strategy(abc.ABC):
"""
Expand All @@ -21,31 +56,36 @@ class Strategy(abc.ABC):

Parameters
----------
k :
Maximum number of nodes to retrieve during traversal.
select_k :
Maximum number of nodes to select and return during traversal.
start_k :
Number of documents to fetch via similarity for starting the traversal.
Number of nodes to fetch via similarity for starting the traversal.
Added to any initial roots provided to the traversal.
adjacent_k :
Number of documents to fetch for each outgoing edge.
Number of nodes to fetch for each outgoing edge.
max_traverse :
Maximum number of nodes to traverse outgoing edges from before returning.
If `None`, there is no limit.
max_depth :
Maximum traversal depth. If `None`, there is no limit.
"""

k: int = 5
select_k: int = 5
start_k: int = 4
adjacent_k: int = 10
max_traverse: int | None = None
max_depth: int | None = None

_query_embedding: list[float] = dataclasses.field(default_factory=list)

@abc.abstractmethod
def discover_nodes(self, nodes: dict[str, Node]) -> None:
def iteration(self, *, nodes: dict[str, Node], tracker: NodeTracker) -> None:
"""
Add discovered nodes to the strategy.
Process the newly discovered nodes on each iteration.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be worth elaborating on this -- the newly discovered seems important. For instance, this means that if a strategy wants to potentially visit them in a future iteration (such as MMR) it needs to remember these nodes and select/traverse them later.


This method updates the strategy's state with nodes discovered during
the traversal process.
This method should call `traverse` and/or `select` as appropriate
to update the nodes that need to be traversed in this iteration or
selected at the end of the retrieval, respectively.

Parameters
----------
Expand All @@ -54,43 +94,20 @@ def discover_nodes(self, nodes: dict[str, Node]) -> None:
"""
...

@abc.abstractmethod
def select_nodes(self, *, limit: int) -> Iterable[Node]:
"""
Select discovered nodes to visit in the next iteration.

This method determines which nodes will be traversed next. If it returns
an empty list, traversal ends even if fewer than `k` nodes have been selected.

Parameters
----------
limit :
Maximum number of nodes to select.

Returns
-------
:
Selected nodes for the next iteration. Traversal ends if this is empty.
"""
...

def finalize_nodes(self, nodes: Iterable[Node]) -> Iterable[Node]:
def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
"""
Finalize the selected nodes.

This method is called before returning the final set of nodes.

Parameters
----------
nodes :
Nodes selected for finalization.

Returns
-------
:
Finalized nodes.
"""
return nodes
# Take the first `self.k` selected items.
# Strategies may override finalize to perform reranking if needed.
return list(selected)[: self.select_k]

@staticmethod
def build(
Expand Down
15 changes: 3 additions & 12 deletions packages/graph-retriever/src/graph_retriever/strategies/eager.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Provide eager (breadth-first) traversal strategy."""

import dataclasses
from collections.abc import Iterable

from typing_extensions import override

from graph_retriever.strategies.base import Strategy
from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node


Expand All @@ -32,14 +31,6 @@ class Eager(Strategy):
Maximum traversal depth. If `None`, there is no limit.
"""

_nodes: list[Node] = dataclasses.field(default_factory=list)

@override
def discover_nodes(self, nodes: dict[str, Node]) -> None:
self._nodes.extend(nodes.values())

@override
def select_nodes(self, *, limit: int) -> Iterable[Node]:
nodes = self._nodes[:limit]
self._nodes = []
return nodes
def iteration(self, *, nodes: dict[str, Node], tracker: NodeTracker) -> None:
    """
    Process the newly discovered nodes on each iteration.

    Eager (breadth-first) traversal selects every newly discovered node
    and queues all of them for the next traversal step.

    Parameters
    ----------
    nodes :
        Newly discovered nodes, keyed by ID.
    tracker :
        Tracker used to record selected and to-be-traversed nodes.
    """
    # Note: `nodes` and `tracker` are keyword-only to match the abstract
    # `Strategy.iteration(self, *, nodes, tracker)` signature; the original
    # override accepted them positionally, which diverged from the base
    # class contract.
    tracker.select_and_traverse(nodes)
109 changes: 50 additions & 59 deletions packages/graph-retriever/src/graph_retriever/strategies/mmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numpy.typing import NDArray
from typing_extensions import override

from graph_retriever.strategies.base import Strategy
from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node
from graph_retriever.utils.math import cosine_similarity

Expand Down Expand Up @@ -51,7 +51,7 @@ class Mmr(Strategy):

Parameters
----------
k :
select_k :
Maximum number of nodes to retrieve during traversal.
start_k :
Number of documents to fetch via similarity for starting the traversal.
Expand Down Expand Up @@ -111,7 +111,7 @@ def _selected_embeddings(self) -> NDArray[np.float32]:
NDArray[np.float32]
(N, dim) ndarray with a row for each selected node.
"""
return np.ndarray((self.k, self._dimensions), dtype=np.float32)
return np.ndarray((self.select_k, self._dimensions), dtype=np.float32)

@cached_property
def _candidate_embeddings(self) -> NDArray[np.float32]:
Expand Down Expand Up @@ -203,8 +203,7 @@ def _pop_candidate(

return candidate, embedding

@override
def select_nodes(self, *, limit: int) -> Iterable[Node]:
def _next(self) -> dict[str, Node]:
"""
Select and pop the best item being considered.

Expand All @@ -214,10 +213,8 @@ def select_nodes(self, *, limit: int) -> Iterable[Node]:
-------
A tuple containing the ID of the best item.
"""
if limit == 0:
return []
if self._best_id is None or self._best_score < self.min_mmr_score:
return []
return {}

# Get the selection and remove from candidates.
selected_id = self._best_id
Expand Down Expand Up @@ -250,61 +247,55 @@ def select_nodes(self, *, limit: int) -> Iterable[Node]:
self._best_score = candidate.score
self._best_id = candidate.node.id

return [selected_node]
return {selected_node.id: selected_node}

@override
def discover_nodes(self, nodes: dict[str, Node]) -> None:
def iteration(self, nodes: dict[str, Node], tracker: NodeTracker) -> None:
"""Add candidates to the consideration set."""
# Determine the keys to actually include.
# These are the candidates that aren't already selected
# or under consideration.

include_ids_set = set(nodes.keys())
include_ids_set.difference_update(self._selected_ids)
include_ids_set.difference_update(self._candidate_id_to_index.keys())
include_ids = list(include_ids_set)

# Now, build up a matrix of the remaining candidate embeddings.
# And add them to the
new_embeddings: NDArray[np.float32] = np.ndarray(
(
len(include_ids),
self._dimensions,
if len(nodes) > 0:
# Build up a matrix of the remaining candidate embeddings.
# And add them to the candidate set
new_embeddings: NDArray[np.float32] = np.ndarray(
(
len(nodes),
self._dimensions,
)
)
)
offset = self._candidate_embeddings.shape[0]
for index, candidate_id in enumerate(include_ids):
self._candidate_id_to_index[candidate_id] = offset + index
new_embeddings[index] = nodes[candidate_id].embedding

# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self._nd_query_embedding)

# Compute the distance metrics of all of pairs in the selected set with
# the new candidates.
redundancy = cosine_similarity(
new_embeddings, self._already_selected_embeddings()
)
for index, candidate_id in enumerate(include_ids):
max_redundancy = 0.0
if redundancy.shape[0] > 0:
max_redundancy = redundancy[index].max()
candidate = _MmrCandidate(
node=nodes[candidate_id],
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self._lambda_mult_complement * max_redundancy,
offset = self._candidate_embeddings.shape[0]
for index, candidate_id in enumerate(nodes.keys()):
self._candidate_id_to_index[candidate_id] = offset + index
new_embeddings[index] = nodes[candidate_id].embedding

# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self._nd_query_embedding)

# Compute the distance metrics of all of pairs in the selected set with
# the new candidates.
redundancy = cosine_similarity(
new_embeddings, self._already_selected_embeddings()
)
self._candidates.append(candidate)

if candidate.score >= self._best_score:
self._best_score = candidate.score
self._best_id = candidate.node.id
for index, candidate_id in enumerate(nodes.keys()):
max_redundancy = 0.0
if redundancy.shape[0] > 0:
max_redundancy = redundancy[index].max()
candidate = _MmrCandidate(
node=nodes[candidate_id],
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self._lambda_mult_complement * max_redundancy,
)
self._candidates.append(candidate)

if candidate.score >= self._best_score:
self._best_score = candidate.score
self._best_id = candidate.node.id

# Add the new embeddings to the candidate set.
self._candidate_embeddings = np.vstack(
(
self._candidate_embeddings,
new_embeddings,
# Add the new embeddings to the candidate set.
self._candidate_embeddings = np.vstack(
(
self._candidate_embeddings,
new_embeddings,
)
)
)

tracker.select_and_traverse(self._next())
26 changes: 11 additions & 15 deletions packages/graph-retriever/src/graph_retriever/strategies/scored.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import dataclasses
import heapq
from collections.abc import Callable, Iterable
from collections.abc import Callable

from typing_extensions import override

from graph_retriever.strategies.base import Strategy
from graph_retriever.strategies.base import NodeTracker, Strategy
from graph_retriever.types import Node


Expand All @@ -13,33 +13,29 @@ def __init__(self, score: float, node: Node) -> None:
self.score = score
self.node = node

def __lt__(self, other) -> bool:
def __lt__(self, other: "_ScoredNode") -> bool:
return other.score < self.score


@dataclasses.dataclass
class Scored(Strategy):
"""Strategy selecing nodes using a scoring function."""
"""Strategy selecting nodes using a scoring function."""

scorer: Callable[[Node], float]
_nodes: list[_ScoredNode] = dataclasses.field(default_factory=list)

per_iteration_limit: int | None = None
per_iteration_limit: int = 2

@override
def discover_nodes(self, nodes: dict[str, Node]) -> None:
def iteration(self, nodes: dict[str, Node], tracker: NodeTracker) -> None:
for node in nodes.values():
heapq.heappush(self._nodes, _ScoredNode(self.scorer(node), node))

@override
def select_nodes(self, *, limit: int) -> Iterable[Node]:
if self.per_iteration_limit and self.per_iteration_limit < limit:
limit = self.per_iteration_limit

selected = []
for _x in range(limit):
selected = {}
for _x in range(self.per_iteration_limit):
if not self._nodes:
break

selected.append(heapq.heappop(self._nodes).node)
return selected
node = heapq.heappop(self._nodes).node
selected[node.id] = node
tracker.select_and_traverse(selected)
Loading
Loading