
Add [a]delete_by_metadata and [a]update_metadata methods to vector store #89

Merged · 3 commits · Oct 3, 2024
Changes from 2 commits
160 changes: 160 additions & 0 deletions libs/astradb/langchain_astradb/vectorstores.py
@@ -827,6 +827,62 @@ async def adelete(
)
return True

def delete_by_metadata_filter(
self,
filter: dict[str, Any], # noqa: A002
) -> int | None:
"""Delete all documents matching a certain metadata filtering condition.

This operation does not use the vector embeddings in any way, it simply
removes all documents whose metadata match the provided condition.
Use with caution: passing an empty filter dictionary results in
completely emptying the vector store.

Args:
filter: Filter on the metadata to apply.

Returns:
The number of deleted documents.
This will be None if the empty filter `{}` is passed, in which case
the store is emptied entirely.
"""
self.astra_env.ensure_db_setup()
metadata_parameter = self.filter_to_query(filter)
del_result = self.astra_env.collection.delete_many(
filter=metadata_parameter,
)
if del_result.deleted_count is not None and del_result.deleted_count >= 0:
return del_result.deleted_count
return None
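
A minimal usage sketch for the method above. The store construction is illustrative only: `my_embeddings`, the endpoint, and the token are placeholders, not part of this diff.

from langchain_astradb import AstraDBVectorStore

store = AstraDBVectorStore(
    embedding=my_embeddings,  # placeholder: any LangChain Embeddings implementation
    collection_name="my_collection",
    api_endpoint="https://<db-id>-<region>.apps.astra.datastax.com",
    token="AstraCS:...",
)

# Remove every document whose metadata matches the filter; the return
# value is the number of documents deleted.
n_deleted = store.delete_by_metadata_filter({"source": "stale_batch"})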

async def adelete_by_metadata_filter(
self,
filter: dict[str, Any], # noqa: A002
) -> int | None:
"""Delete all documents matching a certain metadata filtering condition.

This operation does not use the vector embeddings in any way, it simply
removes all documents whose metadata match the provided condition.
Use with caution: passing an empty filter dictionary results in
completely emptying the vector store.

Args:
filter: Filter on the metadata to apply.

Returns:
The number of deleted documents.
This will be None if the empty filter `{}` is passed, in which case
the store is emptied entirely.
"""
await self.astra_env.aensure_db_setup()
metadata_parameter = self.filter_to_query(filter)
del_result = await self.astra_env.async_collection.delete_many(
filter=metadata_parameter,
)
if del_result.deleted_count is not None and del_result.deleted_count >= 0:
return del_result.deleted_count
return None
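
The async twin behaves identically; a sketch assuming the same hypothetical `store`, called from within a coroutine:

async def purge_stale(store: AstraDBVectorStore) -> None:
    # Returns None only for the empty `{}` filter, which empties the store.
    n_deleted = await store.adelete_by_metadata_filter({"source": "stale_batch"})
    print(f"deleted: {n_deleted}")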

def delete_collection(self) -> None:
"""Completely delete the collection from the database.

@@ -1166,6 +1222,110 @@ async def _replace_document(
raise ValueError(msg)
return inserted_ids

def update_metadata(
self,
id_to_metadata: dict[str, dict],
*,
overwrite_concurrency: int | None = None,
) -> int:
"""Add/overwrite the metadata of existing documents.

For each document to update, the new metadata dictionary is added
to the existing metadata, overwriting individual keys that existed already.

Args:
id_to_metadata: map from the Document IDs to modify to the
new metadata for updating. Keys in this dictionary that
do not correspond to an existing document will be silently ignored.
The values of this map are metadata dictionaries for updating
the documents. Any pre-existing metadata will be merged with
these entries, which take precedence on a key-by-key basis.
overwrite_concurrency: number of threads used to process the updates.
Defaults to the vector store's overall setting if not provided.

Returns:
The number of documents successfully updated (i.e. found to exist,
since even an update with `{}` as the new metadata counts as successful).
"""
self.astra_env.ensure_db_setup()

_max_workers = overwrite_concurrency or self.bulk_insert_overwrite_concurrency
with ThreadPoolExecutor(
max_workers=_max_workers,
) as executor:

def _update_document(
id_md_pair: tuple[str, dict],
) -> UpdateResult:
document_id, update_metadata = id_md_pair
encoded_metadata = self.filter_to_query(update_metadata)
return self.astra_env.collection.update_one(
{"_id": document_id},
{"$set": encoded_metadata},
)

update_results = list(
executor.map(
_update_document,
id_to_metadata.items(),
)
)

return sum(u_res.update_info["n"] for u_res in update_results)
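
A worked sketch of the key-by-key merge performed above, using a hypothetical document ID and metadata:

# Suppose "um_doc_0" currently carries {"inert_field": "I", "updatee_field": "0"}.
n_updated = store.update_metadata(
    {"um_doc_0": {"updatee_field": "1", "reviewed": True}},
    overwrite_concurrency=4,  # worker threads; defaults to the store-wide setting
)
# Resulting metadata: {"inert_field": "I", "updatee_field": "1", "reviewed": True}
# n_updated counts only the IDs that actually existed in the collection.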

async def aupdate_metadata(
self,
id_to_metadata: dict[str, dict],
*,
overwrite_concurrency: int | None = None,
) -> int:
"""Add/overwrite the metadata of existing documents.

For each document to update, the new metadata dictionary is added
to the existing metadata, overwriting individual keys that existed already.

Args:
id_to_metadata: map from the Document IDs to modify to the
new metadata for updating. Keys in this dictionary that
do not correspond to an existing document will be silently ignored.
The values of this map are metadata dictionaries for updating
the documents. Any pre-existing metadata will be merged with
these entries, which take precedence on a key-by-key basis.
overwrite_concurrency: number of threads used to process the updates.
Defaults to the vector store's overall setting if not provided.

Returns:
The number of documents successfully updated (i.e. found to exist,
since even an update with `{}` as the new metadata counts as successful).
"""
await self.astra_env.aensure_db_setup()

sem = asyncio.Semaphore(
overwrite_concurrency or self.bulk_insert_overwrite_concurrency,
)

_async_collection = self.astra_env.async_collection

async def _update_document(
id_md_pair: tuple[str, dict],
) -> UpdateResult:
document_id, update_metadata = id_md_pair
encoded_metadata = self.filter_to_query(update_metadata)
async with sem:
return await _async_collection.update_one(
{"_id": document_id},
{"$set": encoded_metadata},
)

tasks = [
asyncio.create_task(_update_document(id_md_pair))
for id_md_pair in id_to_metadata.items()
]

update_results = await asyncio.gather(*tasks, return_exceptions=False)

return sum(u_res.update_info["n"] for u_res in update_results)
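
The async counterpart throttles the concurrent `update_one` calls with a semaphore instead of a thread pool; a usage sketch from a coroutine, with the same hypothetical `store`:

async def bump_metadata(store: AstraDBVectorStore) -> int:
    return await store.aupdate_metadata(
        {"um_doc_0": {"updatee_field": "2"}},
        overwrite_concurrency=8,  # max in-flight update requests
    )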

@override
def similarity_search(
self,
160 changes: 160 additions & 0 deletions libs/astradb/tests/integration_tests/test_vectorstore.py
@@ -819,6 +819,166 @@ async def test_astradb_vectorstore_massive_insert_replace_async(
for doc, _, doc_id in full_results:
assert doc.page_content == expected_text_by_id[doc_id]

def test_astradb_vectorstore_delete_by_metadata_sync(
self,
vector_store_d2: AstraDBVectorStore,
) -> None:
"""Testing delete_by_metadata_filter."""
full_size = 400
# one in `deletee_ratio` documents will be deleted
deletee_ratio = 3

documents = [
Document(
page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0}
)
for doc_i in range(full_size)
]

inserted_ids0 = vector_store_d2.add_documents(documents)
assert len(inserted_ids0) == len(documents)

d_result0 = vector_store_d2.delete_by_metadata_filter({"deletee": True})
assert d_result0 is not None
assert d_result0 == len([doc for doc in documents if doc.metadata["deletee"]])

d_result1 = vector_store_d2.delete_by_metadata_filter({})
assert d_result1 is None
assert len(vector_store_d2.similarity_search("[1,1]", k=1)) == 0

async def test_astradb_vectorstore_delete_by_metadata_async(
self,
vector_store_d2: AstraDBVectorStore,
) -> None:
"""Testing delete_by_metadata_filter, async version."""
full_size = 400
# one in `deletee_ratio` documents will be deleted
deletee_ratio = 3

documents = [
Document(
page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0}
)
for doc_i in range(full_size)
]

inserted_ids0 = await vector_store_d2.aadd_documents(documents)
assert len(inserted_ids0) == len(documents)

d_result0 = await vector_store_d2.adelete_by_metadata_filter({"deletee": True})
assert d_result0 is not None
assert d_result0 == len([doc for doc in documents if doc.metadata["deletee"]])

d_result1 = await vector_store_d2.adelete_by_metadata_filter({})
assert d_result1 is None
assert len(await vector_store_d2.asimilarity_search("[1,1]", k=1)) == 0

def test_astradb_vectorstore_update_metadata_sync(
self,
vector_store_d2: AstraDBVectorStore,
) -> None:
"""Testing update_metadata."""
# this should not exceed the max number of hits from ANN search
full_size = 20
# one in `updatee_ratio` documents will be updated
updatee_ratio = 2
# set this to lower than full_size // updatee_ratio to test everything.
update_concurrency = 7

def doc_sorter(doc: Document) -> str:
return doc.id or ""

orig_documents0 = [
Document(
page_content="[1,1]",
metadata={
"to_update": doc_i % updatee_ratio == 0,
"inert_field": "I",
"updatee_field": "0",
},
id=f"um_doc_{doc_i}",
)
for doc_i in range(full_size)
]
orig_documents = sorted(orig_documents0, key=doc_sorter)

inserted_ids0 = vector_store_d2.add_documents(orig_documents)
assert len(inserted_ids0) == len(orig_documents)

update_map = {
f"um_doc_{doc_i}": {"updatee_field": "1", "to_update": False}
for doc_i in range(full_size)
if doc_i % updatee_ratio == 0
}
u_result0 = vector_store_d2.update_metadata(
update_map,
overwrite_concurrency=update_concurrency,
)
assert u_result0 == len(update_map)

all_documents = sorted(
vector_store_d2.similarity_search("[1,1]", k=full_size),
key=doc_sorter,
)
assert len(all_documents) == len(orig_documents)
for doc, orig_doc in zip(all_documents, orig_documents):
assert doc.id == orig_doc.id
if doc.id in update_map:
assert doc.metadata == orig_doc.metadata | update_map[doc.id]

async def test_astradb_vectorstore_update_metadata_async(
self,
vector_store_d2: AstraDBVectorStore,
) -> None:
"""Testing update_metadata, async version."""
# this should not exceed the max number of hits from ANN search
full_size = 20
# one in `updatee_ratio` documents will be updated
updatee_ratio = 2
# set this to lower than full_size // updatee_ratio to test everything.
update_concurrency = 7

def doc_sorter(doc: Document) -> str:
return doc.id or ""

orig_documents0 = [
Document(
page_content="[1,1]",
metadata={
"to_update": doc_i % updatee_ratio == 0,
"inert_field": "I",
"updatee_field": "0",
},
id=f"um_doc_{doc_i}",
)
for doc_i in range(full_size)
]
orig_documents = sorted(orig_documents0, key=doc_sorter)

inserted_ids0 = await vector_store_d2.aadd_documents(orig_documents)
assert len(inserted_ids0) == len(orig_documents)

update_map = {
f"um_doc_{doc_i}": {"updatee_field": "1", "to_update": False}
for doc_i in range(full_size)
if doc_i % updatee_ratio == 0
}
u_result0 = await vector_store_d2.aupdate_metadata(
update_map,
overwrite_concurrency=update_concurrency,
)
assert u_result0 == len(update_map)

all_documents = sorted(
await vector_store_d2.asimilarity_search("[1,1]", k=full_size),
key=doc_sorter,
)
assert len(all_documents) == len(orig_documents)
for doc, orig_doc in zip(all_documents, orig_documents):
assert doc.id == orig_doc.id
if doc.id in update_map:
assert doc.metadata == orig_doc.metadata | update_map[doc.id]

def test_astradb_vectorstore_mmr_sync(
self,
vector_store_d2: AstraDBVectorStore,