Skip to content

Commit

Permalink
Vector Store, full separation codec / vectorstore (#113)
Browse files Browse the repository at this point in the history
* centralize codec's id/vector encoding; add multi-ids encoding

* move all indexing, _id and similarity management into codecs so that it cleanly passes through the coded layer all the time

* trading encode_id[s] for encode_query

* them docstrings
  • Loading branch information
hemidactylus authored Feb 18, 2025
1 parent f6e73fa commit 3715b62
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 131 deletions.
169 changes: 136 additions & 33 deletions libs/astradb/langchain_astradb/utils/vector_store_codecs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Classes to handle encoding of documents on DB for the Vector Store.."""
"""Classes to handle encoding of Documents on DB for the Vector Store.."""

from __future__ import annotations

import logging
import warnings
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Iterable

from langchain_core.documents import Document
from typing_extensions import override
Expand All @@ -16,14 +16,44 @@
)
FLATTEN_CONFLICT_MSG = "Cannot flatten metadata: field name overlap for '{field}'."

STANDARD_INDEXING_OPTIONS_DEFAULT = {"allow": ["metadata"]}

logger = logging.getLogger(__name__)


def _default_decode_vector(astra_doc: dict[str, Any]) -> list[float] | None:
"""Extract the embedding vector from an Astra DB document."""
return astra_doc.get("$vector")


def _default_metadata_key_to_field_identifier(md_key: str) -> str:
"""Rewrite a metadata key name to its full path in the 'default' encoding.
The input `md_key` is an "abstract" metadata key, while the return value
identifies its actual full-path location on an Astra DB document encoded in the
'default' way (i.e. with a nested `metadata` dictionary).
"""
return f"metadata.{md_key}"


def _flat_metadata_key_to_field_identifier(md_key: str) -> str:
"""Rewrite a metadata key name to its full path in the 'flat' encoding.
The input `md_key` is an "abstract" metadata key, while the return value
identifies its actual full-path location on an Astra DB document encoded in the
'flat' way (i.e. metadata fields appearing at top-level in the Astra DB document).
"""
return md_key


def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]:
"""Encode an "abstract" metadata condition for the 'default' encoding.
The input can express a query clause on metadata and uses just the metadata field
names, possibly connected/nested through AND and ORs. The output makes key names
into their full path-identifiers (e.g. "metadata.xyz") according to the 'default'
encoding scheme for Astra DB documents.
"""
metadata_filter = {}
for k, v in filter_dict.items():
# Key in this dict starting with $ are supposedly operators and as such
Expand All @@ -37,32 +67,48 @@ def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]:
# assume each list item can be fed back to this function
metadata_filter[k] = _default_encode_filter(v) # type: ignore[assignment]
else:
metadata_filter[f"metadata.{k}"] = v
metadata_filter[_default_metadata_key_to_field_identifier(k)] = v

return metadata_filter


def _default_encode_id(filter_id: str) -> dict[str, Any]:
def _astra_generic_encode_id(filter_id: str) -> dict[str, Any]:
"""Encoding of a single Document ID as a query clause for an Astra DB document."""
return {"_id": filter_id}


def _default_encode_vector_sort(vector: list[float]) -> dict[str, Any]:
def _astra_generic_encode_ids(filter_ids: list[str]) -> dict[str, Any]:
"""Encoding of Document IDs as a query clause for an Astra DB document.
This function picks the right, and most concise, expression based on the
multiplicity of the provided IDs.
"""
if len(filter_ids) == 1:
return _astra_generic_encode_id(filter_ids[0])
return {"_id": {"$in": filter_ids}}


def _astra_generic_encode_vector_sort(vector: list[float]) -> dict[str, Any]:
"""Encoding of a vector-based sort as a query clause for an Astra DB document."""
return {"$vector": vector}


class _AstraDBVectorStoreDocumentCodec(ABC):
"""A document codec for the Astra DB vector store.
"""A Document codec for the Astra DB vector store.
The document codec contains the information for consistent interaction
Document codecs hold the logic consistent interaction
with documents as stored on the Astra DB collection.
In this context, 'Document' (capital D) refers to the LangChain class,
while 'Astra DB document' refers to the JSON-like object stored on DB.
Implementations of this class must:
- define how to encode/decode documents consistently to and from
Astra DB collections. The two operations must, so to speak, combine
to the identity on both sides (except for the quirks of their signatures).
- provide the adequate projection dictionaries for running find
operations on Astra DB, with and without the field containing the vector.
- encode IDs to the `_id` field on Astra DB.
- encode Document IDs to the right field on Astra DB ("_id" for Collections).
- define the name of the field storing the textual content of the Document.
- define whether embeddings are computed server-side (with $vectorize) or not.
"""
Expand Down Expand Up @@ -132,17 +178,72 @@ def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
"""

@abstractmethod
def encode_id(self, filter_id: str) -> dict[str, Any]:
"""Encode an ID as a filter for use in Astra DB queries.
def metadata_key_to_field_identifier(self, md_key: str) -> str:
"""Express an 'abstract' metadata key as a full Data API field identifier."""

@property
@abstractmethod
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
"""Provide the default indexing policy if the collection must be created."""

def get_id(self, astra_document: dict[str, Any]) -> str:
"""Return the ID of an encoded document (= a raw JSON read from DB)."""
return astra_document["_id"]

def get_similarity(self, astra_document: dict[str, Any]) -> float:
"""Return the similarity of an encoded document (= a raw JSON read from DB).
This method assumes its argument comes from a suitable vector search.
"""
return astra_document["$similarity"]

def encode_query(
self,
*,
ids: Iterable[str] | None = None,
filter_dict: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Prepare an encoded query according to the Astra DB document encoding.
The method optionally accepts both IDs and metadata filters. The two,
if passed together, are automatically combined with an AND operation.
In other words, if passing both IDs and a metadata filtering clause,
the resulting query would return Astra DB documents matching the metadata
clause AND having an ID among those provided to this method. If, instead,
an OR is required, one should run two separate queries and subsequently merge
the result (taking care of avoiding duplcates).
Args:
filter_id: the ID value to filter on.
ids: an iterable over Document IDs. If provided, the resulting Astra DB
query dictionary expresses the requirement that returning documents
have an ID among those provided here. Passing an empty iterable,
or None, results in a query with no conditions on the IDs at all.
filter_dict: a metadata filtering part. If provided, if must refer to
metadata keys by their bare name (such as `{"key": 123}`).
This filter can combine nested conditions with "$or"/"$and" connectors,
for example:
- `{"tag": "a"}`
- `{"$or": [{"tag": "a"}, "label": "b"]}`
- `{"$and": [{"tag": {"$in": ["a", "z"]}}, "label": "b"]}`
Returns:
an filter clause for use in Astra DB's find queries.
a query dictionary ready to be used in an Astra DB find operation on
a collection.
"""
clauses: list[dict[str, Any]] = []
_ids_list = list(ids or [])
if _ids_list:
clauses.append(_astra_generic_encode_ids(_ids_list))
if filter_dict:
clauses.append(self.encode_filter(filter_dict))

if clauses:
if len(clauses) > 1:
return {"$and": clauses}
return clauses[0]
return {}

@abstractmethod
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
"""Encode a vector as a sort to use for Astra DB queries.
Expand All @@ -152,6 +253,7 @@ def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
Returns:
an order clause for use in Astra DB's find queries.
"""
return _astra_generic_encode_vector_sort(vector)


class _DefaultVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
Expand Down Expand Up @@ -226,12 +328,12 @@ def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
return _default_encode_filter(filter_dict)

@override
def encode_id(self, filter_id: str) -> dict[str, Any]:
return _default_encode_id(filter_id)
def metadata_key_to_field_identifier(self, md_key: str) -> str:
return _default_metadata_key_to_field_identifier(md_key)

@override
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
return _default_encode_vector_sort(vector)
@property
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
return STANDARD_INDEXING_OPTIONS_DEFAULT


class _DefaultVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
Expand Down Expand Up @@ -308,13 +410,13 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None:
def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
return _default_encode_filter(filter_dict)

@override
def encode_id(self, filter_id: str) -> dict[str, Any]:
return _default_encode_id(filter_id)
@property
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
return STANDARD_INDEXING_OPTIONS_DEFAULT

@override
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
return _default_encode_vector_sort(vector)
def metadata_key_to_field_identifier(self, md_key: str) -> str:
return _default_metadata_key_to_field_identifier(md_key)


class _FlatVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
Expand Down Expand Up @@ -396,13 +498,13 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None:
def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
return filter_dict

@override
def encode_id(self, filter_id: str) -> dict[str, Any]:
return _default_encode_id(filter_id)
@property
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
return {"deny": [self.content_field]}

@override
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
return _default_encode_vector_sort(vector)
def metadata_key_to_field_identifier(self, md_key: str) -> str:
return _flat_metadata_key_to_field_identifier(md_key)


class _FlatVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
Expand Down Expand Up @@ -477,10 +579,11 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None:
def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
return filter_dict

@override
def encode_id(self, filter_id: str) -> dict[str, Any]:
return _default_encode_id(filter_id)
@property
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
# $vectorize cannot be de-indexed explicitly (the API manages it entirely).
return {}

@override
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
return _default_encode_vector_sort(vector)
def metadata_key_to_field_identifier(self, md_key: str) -> str:
return _flat_metadata_key_to_field_identifier(md_key)
Loading

0 comments on commit 3715b62

Please sign in to comment.