diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index b2accd8..bd9915d 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -1218,11 +1218,9 @@ async def _get_adjacent( ) ) - results: list[ - tuple[list[float], list[tuple[Document, list[float]]]] - ] = await asyncio.gather(*tasks) + results: list[list[tuple[Document, list[float]]]] = await asyncio.gather(*tasks) - for _, result in results: + for result in results: for doc, embedding in result: if doc.id is not None: retrieved_docs[doc.id] = doc diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index 6fbcc40..724b89c 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -1787,64 +1787,144 @@ async def asimilarity_search_with_score_id_by_vector( filter=filter, ) - async def asimilarity_search_with_embedding( + def similarity_search_with_embedding_by_vector( self, - query_or_embedding: str | list[float], + embedding: list[float], k: int = 4, filter: dict[str, Any] | None = None, # noqa: A002 - ) -> tuple[list[float], list[tuple[Document, list[float]]]]: - """Returns the embedded query and docs most similar to query. + ) -> list[tuple[Document, list[float]]]: + """Return docs most similar to the query embedding vector with + their document embedding vectors. Args: - query_or_embedding: Query or Embedding to look up - documents similar to. + embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter: Filter on the metadata to apply. Returns: - (query_embedding, List of (Document, embedding) most similar to the query). + (The query embedding vector, The list of (Document, embedding), + the most similar to the query vector.). """ - await self.astra_env.aensure_db_setup() + sort = self.document_codec.encode_vector_sort(vector=embedding) + _, doc_emb_list = self._similarity_search_with_embedding_by_sort( + sort=sort, k=k, filter=filter + ) + return doc_emb_list - sort: dict[str, Any] = {} - include_sort_vector: bool = False - query_embedding: list[float] = [] + async def asimilarity_search_with_embedding_by_vector( + self, + embedding: list[float], + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> list[tuple[Document, list[float]]]: + """Return docs most similar to the query embedding vector with + their document embedding vectors. - if isinstance(query_or_embedding, str): - query: str = query_or_embedding - if self.document_codec.server_side_embeddings: - include_sort_vector = True - sort = {"$vectorize": query} - else: - query_embedding = self._get_safe_embedding().embed_query(text=query) - sort = self.document_codec.encode_vector_sort(vector=query_embedding) - elif isinstance(query_or_embedding, list): - query_embedding = query_or_embedding + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + + Returns: + (The query embedding vector, The list of (Document, embedding), + the most similar to the query vector.). + """ + sort = self.document_codec.encode_vector_sort(vector=embedding) + _, doc_emb_list = await self._asimilarity_search_with_embedding_by_sort( + sort=sort, k=k, filter=filter + ) + return doc_emb_list + + def similarity_search_with_embedding( + self, + query: str, + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> tuple[list[float], list[tuple[Document, list[float]]]]: + """Return the embedded query vector and docs most similar + to the query embedding vector with their document embedding + vectors. + + Args: + query: Query to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + + Returns: + (The query embedding vector, The list of (Document, embedding), + the most similar to the query vector.). + """ + + if self.document_codec.server_side_embeddings: + sort = {"$vectorize": query} + else: + query_embedding = self._get_safe_embedding().embed_query(text=query) + # shortcut return if query isn't needed. + if k == 0: + return (query_embedding, []) sort = self.document_codec.encode_vector_sort(vector=query_embedding) + + return self._similarity_search_with_embedding_by_sort( + sort=sort, k=k, filter=filter + ) + + async def asimilarity_search_with_embedding( + self, + query: str, + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> tuple[list[float], list[tuple[Document, list[float]]]]: + """Return the embedded query vector and docs most similar + to the query embedding vector with their document embedding + vectors. + + Args: + query: Query to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + + Returns: + (The query embedding vector, The list of (Document, embedding), + the most similar to the query vector.). + """ + + if self.document_codec.server_side_embeddings: + sort = {"$vectorize": query} else: - msg = ( - "Expected a 'str' or a 'list[float]' for 'query_or_embedding', ", - f"got {type(query_or_embedding)} instead.", - ) - raise TypeError(msg) + query_embedding = self._get_safe_embedding().embed_query(text=query) + # shortcut return if query isn't needed. + if k == 0: + return (query_embedding, []) + sort = self.document_codec.encode_vector_sort(vector=query_embedding) - # shortcut return if query isn't needed. - if k == 0 and len(query_embedding) > 0: - return (query_embedding, []) + return await self._asimilarity_search_with_embedding_by_sort( + sort=sort, k=k, filter=filter + ) + async def _asimilarity_search_with_embedding_by_sort( + self, + sort: dict[str, Any], + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> tuple[list[float], list[tuple[Document, list[float]]]]: + """Run ANN search with a provided sort clause. + + Returns: + (query_embedding, List of (Document, embedding) most similar to the query). + """ + await self.astra_env.aensure_db_setup() async_cursor = self.astra_env.async_collection.find( filter=self.filter_to_query(filter), projection=self.document_codec.full_projection, limit=k, - include_sort_vector=include_sort_vector, + include_sort_vector=True, sort=sort, ) - if include_sort_vector: - sort_vector = await async_cursor.get_sort_vector() - if sort_vector is None: - msg = "Unable to retrieve the server-side embedding of the query." - raise ValueError(msg) - query_embedding = sort_vector + sort_vector = await async_cursor.get_sort_vector() + if sort_vector is None: + msg = "Unable to retrieve the server-side embedding of the query." + raise ValueError(msg) + query_embedding = sort_vector return ( query_embedding, @@ -1861,6 +1941,46 @@ async def asimilarity_search_with_embedding( ], ) + def _similarity_search_with_embedding_by_sort( + self, + sort: dict[str, Any], + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> tuple[list[float], list[tuple[Document, list[float]]]]: + """Run ANN search with a provided sort clause. + + Returns: + (query_embedding, List of (Document, embedding) most similar to the query). + """ + self.astra_env.ensure_db_setup() + cursor = self.astra_env.collection.find( + filter=self.filter_to_query(filter), + projection=self.document_codec.full_projection, + limit=k, + include_sort_vector=True, + sort=sort, + ) + sort_vector = cursor.get_sort_vector() + if sort_vector is None: + msg = "Unable to retrieve the server-side embedding of the query." + raise ValueError(msg) + query_embedding = sort_vector + + return ( + query_embedding, + [ + (doc, emb) + for (doc, emb) in ( + ( + self.document_codec.decode(hit), + self.document_codec.decode_vector(hit), + ) + for hit in cursor + ) + if doc is not None and emb is not None + ], + ) + async def _asimilarity_search_with_score_id_by_sort( self, sort: dict[str, Any], diff --git a/libs/astradb/tests/integration_tests/test_vectorstore.py b/libs/astradb/tests/integration_tests/test_vectorstore.py index 3475509..785deb8 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstore.py +++ b/libs/astradb/tests/integration_tests/test_vectorstore.py @@ -30,6 +30,11 @@ from .conftest import AstraDBCredentials +def assert_list_of_numeric(value: list[float]) -> None: + assert isinstance(value, list) + assert all(isinstance(item, (float, int)) for item in value) + + @pytest.fixture def metadata_documents() -> list[Document]: """Documents for metadata and id tests""" @@ -1410,30 +1415,17 @@ async def test_astradb_vectorstore_similarity_scale_async( assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON @pytest.mark.parametrize( - ("is_vectorize", "vector_store", "query_or_embedding_mode"), + "vector_store", [ - (False, "vector_store_d2", "query"), - (False, "vector_store_d2", "embedding"), - (False, "vector_store_d2", "other"), - (True, "vector_store_vz", "query"), - (True, "vector_store_vz", "embedding"), - (True, "vector_store_vz", "other"), - ], - ids=[ - "nonvectorize_store_with_query", - "nonvectorize_store_with_embedding", - "nonvectorize_store_with_other", - "vectorize_store_with_query", - "vectorize_store_with_embedding", - "vectorize_store_with_other", + "vector_store_d2", + "vector_store_vz", ], + ids=["nonvectorize_store", "vectorize_store"], ) async def test_astradb_vectorstore_asimilarity_search_with_embedding( self, *, - is_vectorize: bool, vector_store: str, - query_or_embedding_mode: str, metadata_documents: list[Document], request: pytest.FixtureRequest, ) -> None: @@ -1443,31 +1435,77 @@ async def test_astradb_vectorstore_asimilarity_search_with_embedding( vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) await vstore.aadd_documents(metadata_documents) - def assert_list_of_numeric(value: list[float]) -> None: - assert isinstance(value, list) - assert all(isinstance(item, (float, int)) for item in value) + query_embedding, results = await vstore.asimilarity_search_with_embedding( + query="[-1,2]" + ) - if query_or_embedding_mode == "query": - query_embedding, results = await vstore.asimilarity_search_with_embedding( - query_or_embedding="[-1,2]" - ) - elif query_or_embedding_mode == "embedding": - vector_dimensions = 1536 if is_vectorize else 2 - query_embedding, results = await vstore.asimilarity_search_with_embedding( - query_or_embedding=[ - random.uniform(0.0, 1.0) # noqa: S311 - for _ in range(vector_dimensions) - ] - ) - else: - with pytest.raises( - TypeError, - match=r"Expected a", - ): - await vstore.asimilarity_search_with_embedding( - query_or_embedding={"test": "error"} # type: ignore # noqa: PGH003 - ) - return + assert_list_of_numeric(query_embedding) + assert isinstance(results, list) + assert len(results) > 0 + (doc, embedding) = results[0] + assert isinstance(doc, Document) + assert_list_of_numeric(embedding) + + @pytest.mark.parametrize( + ("is_vectorize", "vector_store"), + [ + (False, "vector_store_d2"), + (True, "vector_store_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + async def test_astradb_vectorstore_asimilarity_search_with_embedding_by_vector( + self, + *, + is_vectorize: bool, + vector_store: str, + metadata_documents: list[Document], + request: pytest.FixtureRequest, + ) -> None: + """asimilarity_search_with_embedding_by_vector is used as the building + block for other components (like AstraDBGraphVectorStore). + """ + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + await vstore.aadd_documents(metadata_documents) + + vector_dimensions = 1536 if is_vectorize else 2 + results = await vstore.asimilarity_search_with_embedding_by_vector( + embedding=[ + random.uniform(0.0, 1.0) # noqa: S311 + for _ in range(vector_dimensions) + ] + ) + + assert isinstance(results, list) + assert len(results) > 0 + (doc, embedding) = results[0] + assert isinstance(doc, Document) + assert_list_of_numeric(embedding) + + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_vz", + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + def test_astradb_vectorstore_similarity_search_with_embedding( + self, + *, + vector_store: str, + metadata_documents: list[Document], + request: pytest.FixtureRequest, + ) -> None: + """similarity_search_with_embedding is used as the building + block for other components (like AstraDBGraphVectorStore). + """ + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.add_documents(metadata_documents) + + query_embedding, results = vstore.similarity_search_with_embedding( + query="[-1,2]" + ) assert_list_of_numeric(query_embedding) assert isinstance(results, list) @@ -1476,6 +1514,42 @@ def assert_list_of_numeric(value: list[float]) -> None: assert isinstance(doc, Document) assert_list_of_numeric(embedding) + @pytest.mark.parametrize( + ("is_vectorize", "vector_store"), + [ + (False, "vector_store_d2"), + (True, "vector_store_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + def test_astradb_vectorstore_similarity_search_with_embedding_by_vector( + self, + *, + is_vectorize: bool, + vector_store: str, + metadata_documents: list[Document], + request: pytest.FixtureRequest, + ) -> None: + """similarity_search_with_embedding_by_vector is used as the building + block for other components (like AstraDBGraphVectorStore). + """ + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.aadd_documents(metadata_documents) + + vector_dimensions = 1536 if is_vectorize else 2 + results = vstore.similarity_search_with_embedding_by_vector( + embedding=[ + random.uniform(0.0, 1.0) # noqa: S311 + for _ in range(vector_dimensions) + ] + ) + + assert isinstance(results, list) + assert len(results) > 0 + (doc, embedding) = results[0] + assert isinstance(doc, Document) + assert_list_of_numeric(embedding) + @pytest.mark.parametrize( "vector_store", [