Skip to content

Commit

Permalink
split search_with_embedding method and added sync version
Browse files Browse the repository at this point in the history
  • Loading branch information
epinzur committed Oct 24, 2024
1 parent 9dcc670 commit c927282
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 81 deletions.
6 changes: 2 additions & 4 deletions libs/astradb/langchain_astradb/graph_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,11 +1218,9 @@ async def _get_adjacent(
)
)

results: list[
tuple[list[float], list[tuple[Document, list[float]]]]
] = await asyncio.gather(*tasks)
results: list[list[tuple[Document, list[float]]]] = await asyncio.gather(*tasks)

for _, result in results:
for result in results:
for doc, embedding in result:
if doc.id is not None:
retrieved_docs[doc.id] = doc
Expand Down
192 changes: 156 additions & 36 deletions libs/astradb/langchain_astradb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,64 +1787,144 @@ async def asimilarity_search_with_score_id_by_vector(
filter=filter,
)

async def asimilarity_search_with_embedding(
def similarity_search_with_embedding_by_vector(
self,
query_or_embedding: str | list[float],
embedding: list[float],
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
) -> tuple[list[float], list[tuple[Document, list[float]]]]:
"""Returns the embedded query and docs most similar to query.
) -> list[tuple[Document, list[float]]]:
"""Return docs most similar to the query embedding vector with
their document embedding vectors.
Args:
query_or_embedding: Query or Embedding to look up
documents similar to.
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
(query_embedding, List of (Document, embedding) most similar to the query).
(The query embedding vector, The list of (Document, embedding),
the most similar to the query vector.).
"""

Check failure on line 1807 in libs/astradb/langchain_astradb/vectorstores.py

View workflow job for this annotation

GitHub Actions / cd libs/astradb / make lint #3.9

Ruff (D205)

langchain_astradb/vectorstores.py:1796:9: D205 1 blank line required between summary line and description
await self.astra_env.aensure_db_setup()
sort = self.document_codec.encode_vector_sort(vector=embedding)
_, doc_emb_list = self._similarity_search_with_embedding_by_sort(
sort=sort, k=k, filter=filter
)
return doc_emb_list

sort: dict[str, Any] = {}
include_sort_vector: bool = False
query_embedding: list[float] = []
async def asimilarity_search_with_embedding_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
) -> list[tuple[Document, list[float]]]:
"""Return docs most similar to the query embedding vector with
their document embedding vectors.
if isinstance(query_or_embedding, str):
query: str = query_or_embedding
if self.document_codec.server_side_embeddings:
include_sort_vector = True
sort = {"$vectorize": query}
else:
query_embedding = self._get_safe_embedding().embed_query(text=query)
sort = self.document_codec.encode_vector_sort(vector=query_embedding)
elif isinstance(query_or_embedding, list):
query_embedding = query_or_embedding
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
(The query embedding vector, The list of (Document, embedding),
the most similar to the query vector.).
"""

Check failure on line 1831 in libs/astradb/langchain_astradb/vectorstores.py

View workflow job for this annotation

GitHub Actions / cd libs/astradb / make lint #3.9

Ruff (D205)

langchain_astradb/vectorstores.py:1820:9: D205 1 blank line required between summary line and description
sort = self.document_codec.encode_vector_sort(vector=embedding)
_, doc_emb_list = await self._asimilarity_search_with_embedding_by_sort(
sort=sort, k=k, filter=filter
)
return doc_emb_list

def similarity_search_with_embedding(
self,
query: str,
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
) -> tuple[list[float], list[tuple[Document, list[float]]]]:
"""Return the embedded query vector and docs most similar
to the query embedding vector with their document embedding
vectors.
Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
(The query embedding vector, The list of (Document, embedding),
the most similar to the query vector.).
"""

Check failure on line 1856 in libs/astradb/langchain_astradb/vectorstores.py

View workflow job for this annotation

GitHub Actions / cd libs/astradb / make lint #3.9

Ruff (D202)

langchain_astradb/vectorstores.py:1844:9: D202 No blank lines allowed after function docstring (found 1)

Check failure on line 1856 in libs/astradb/langchain_astradb/vectorstores.py

View workflow job for this annotation

GitHub Actions / cd libs/astradb / make lint #3.9

Ruff (D205)

langchain_astradb/vectorstores.py:1844:9: D205 1 blank line required between summary line and description

if self.document_codec.server_side_embeddings:
sort = {"$vectorize": query}
else:
query_embedding = self._get_safe_embedding().embed_query(text=query)
# shortcut return if query isn't needed.
if k == 0:
return (query_embedding, [])
sort = self.document_codec.encode_vector_sort(vector=query_embedding)

return self._similarity_search_with_embedding_by_sort(
sort=sort, k=k, filter=filter
)

async def asimilarity_search_with_embedding(
self,
query: str,
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
) -> tuple[list[float], list[tuple[Document, list[float]]]]:
"""Return the embedded query vector and docs most similar
to the query embedding vector with their document embedding
vectors.
Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
(The query embedding vector, The list of (Document, embedding),
the most similar to the query vector.).
"""

Check failure on line 1889 in libs/astradb/langchain_astradb/vectorstores.py

View workflow job for this annotation

GitHub Actions / cd libs/astradb / make lint #3.9

Ruff (D202)

langchain_astradb/vectorstores.py:1877:9: D202 No blank lines allowed after function docstring (found 1)

Check failure on line 1889 in libs/astradb/langchain_astradb/vectorstores.py

View workflow job for this annotation

GitHub Actions / cd libs/astradb / make lint #3.9

Ruff (D205)

langchain_astradb/vectorstores.py:1877:9: D205 1 blank line required between summary line and description

if self.document_codec.server_side_embeddings:
sort = {"$vectorize": query}
else:
msg = (
"Expected a 'str' or a 'list[float]' for 'query_or_embedding', ",
f"got {type(query_or_embedding)} instead.",
)
raise TypeError(msg)
query_embedding = self._get_safe_embedding().embed_query(text=query)
# shortcut return if query isn't needed.
if k == 0:
return (query_embedding, [])
sort = self.document_codec.encode_vector_sort(vector=query_embedding)

# shortcut return if query isn't needed.
if k == 0 and len(query_embedding) > 0:
return (query_embedding, [])
return await self._asimilarity_search_with_embedding_by_sort(
sort=sort, k=k, filter=filter
)

async def _asimilarity_search_with_embedding_by_sort(
self,
sort: dict[str, Any],
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
) -> tuple[list[float], list[tuple[Document, list[float]]]]:
"""Run ANN search with a provided sort clause.
Returns:
(query_embedding, List of (Document, embedding) most similar to the query).
"""
await self.astra_env.aensure_db_setup()
async_cursor = self.astra_env.async_collection.find(
filter=self.filter_to_query(filter),
projection=self.document_codec.full_projection,
limit=k,
include_sort_vector=include_sort_vector,
include_sort_vector=True,
sort=sort,
)
if include_sort_vector:
sort_vector = await async_cursor.get_sort_vector()
if sort_vector is None:
msg = "Unable to retrieve the server-side embedding of the query."
raise ValueError(msg)
query_embedding = sort_vector
sort_vector = await async_cursor.get_sort_vector()
if sort_vector is None:
msg = "Unable to retrieve the server-side embedding of the query."
raise ValueError(msg)
query_embedding = sort_vector

return (
query_embedding,
Expand All @@ -1861,6 +1941,46 @@ async def asimilarity_search_with_embedding(
],
)

def _similarity_search_with_embedding_by_sort(
self,
sort: dict[str, Any],
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
) -> tuple[list[float], list[tuple[Document, list[float]]]]:
"""Run ANN search with a provided sort clause.
Returns:
(query_embedding, List of (Document, embedding) most similar to the query).
"""
self.astra_env.ensure_db_setup()
cursor = self.astra_env.collection.find(
filter=self.filter_to_query(filter),
projection=self.document_codec.full_projection,
limit=k,
include_sort_vector=True,
sort=sort,
)
sort_vector = cursor.get_sort_vector()
if sort_vector is None:
msg = "Unable to retrieve the server-side embedding of the query."
raise ValueError(msg)
query_embedding = sort_vector

return (
query_embedding,
[
(doc, emb)
for (doc, emb) in (
(
self.document_codec.decode(hit),
self.document_codec.decode_vector(hit),
)
for hit in cursor
)
if doc is not None and emb is not None
],
)

async def _asimilarity_search_with_score_id_by_sort(
self,
sort: dict[str, Any],
Expand Down
Loading

0 comments on commit c927282

Please sign in to comment.