From a597eec5ea08c64878438acf9c32f2745405d740 Mon Sep 17 00:00:00 2001
From: Louis Beaumont
Date: Mon, 13 Feb 2023 15:14:47 +0000
Subject: [PATCH] 0.5.9: feat: abstract away database

---
 .gitignore            |   2 +
 Makefile              |   5 +--
 search/api.py         |  91 +++++++++++++--------------------
 search/db.py          |  60 +++++++++++++++++++++++
 search/pinecone_db.py | 101 ++++++++++++++++++++++++++++++++++++++++++
 search/test_main.py   |  64 +++++++++++++++-----------
 search/weaviate_db.py |  55 +++++++++++++++++++++++
 setup.py              |   2 +-
 8 files changed, 285 insertions(+), 95 deletions(-)
 create mode 100644 search/db.py
 create mode 100644 search/pinecone_db.py
 create mode 100644 search/weaviate_db.py

diff --git a/.gitignore b/.gitignore
index 82937851..905b1236 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,8 @@
 # python garbage
 __pycache__
 .pytest_cache
+*.egg-info
+build
 
 # python virtual environment
 env
diff --git a/Makefile b/Makefile
index ba778efa..58bb7796 100644
--- a/Makefile
+++ b/Makefile
@@ -6,9 +6,8 @@ LOCAL_PORT="8000"
 
 install: ## [DEVELOPMENT] Install the API dependencies
 	virtualenv env; \
-	source env/bin/activate; \
-	pip install -r requirements.txt; \
-	pip install -r requirements-test.txt
+	. env/bin/activate; \
+	pip install .[all]
 	@echo "Done, run '\033[0;31msource env/bin/activate\033[0m' to activate the virtual environment"
 
 run: ## [DEVELOPMENT] Run the API
diff --git a/search/api.py b/search/api.py
index 6b3da448..cb6f7e7b 100644
--- a/search/api.py
+++ b/search/api.py
@@ -1,3 +1,4 @@
+import asyncio
 import hashlib
 from multiprocessing.pool import ThreadPool
 import time
@@ -15,11 +16,10 @@
     SearchRequest,
 )
 from fastapi.responses import JSONResponse
-import pinecone
 import urllib.parse
 import numpy as np
+from search.pinecone_db import Pinecone
 from search.settings import Settings, get_settings
-from .utils import BatchGenerator, too_big_rows
 import openai
 import sentry_sdk
@@ -29,7 +29,7 @@
 from tenacity.after import after_log
 from tenacity.stop import stop_after_attempt
 import requests
-from typing import Tuple
+from typing import List, Tuple
 
 settings = get_settings()
 
@@ -129,18 +129,18 @@ async def firebase_auth(request: Request, call_next) -> Tuple[str, str]:
     allow_headers=["*"],
 )
 
-pinecone.init(
-    api_key=settings.pinecone_api_key, environment=settings.pinecone_environment
+vector_database = Pinecone(
+    api_key=settings.pinecone_api_key,
+    environment=settings.pinecone_environment,
+    index_name=settings.pinecone_index,
 )
 openai.api_key = settings.openai_api_key
 openai.organization = settings.openai_organization
-pinecone_index = settings.pinecone_index
-index = pinecone.Index(pinecone_index, pool_threads=8)
 
 
 @app.on_event("startup")
-def startup_event():
-    result = index.fetch(ids=["foo"])  # TODO: container startup check
+async def startup_event():
+    result = await vector_database.fetch(ids=["foo"])  # TODO: container startup check
     if result:
         logger.info("Properly connected to Pinecone")
     else:
@@ -188,61 +188,26 @@ def embed(
     after=after_log(logger, logging.ERROR),
     stop=stop_after_attempt(3),
 )
-def upload_embeddings_to_vector_database(df: DataFrame, namespace: str):
-    # TODO: batch size should depend on payload
-    df_batcher = BatchGenerator(UPLOAD_BATCH_SIZE)
-    logger.info("Uploading vectors namespace..")
-    start_time_upload = time.time()
-    batches = [batch_df for batch_df in df_batcher(df)]
-
-    def _insert(batch_df: DataFrame):
-        bigs = too_big_rows(batch_df)
-        if len(bigs) > 0:
-            logger.info(f"Ignoring {len(bigs)} rows that are too big")
-        # remove rows that are too big, in the right axis
-        batch_df = batch_df.drop(bigs, axis=0)
-        response = index.upsert(
-            vectors=zip(
-                # pinecone needs to have the document path url encoded
-                batch_df.id.apply(urllib.parse.quote).tolist(),
-                batch_df.embedding,
-                [
-                    {
-                        "data": data,
-                    }
-                    for data in batch_df.data
-                ],
-            ),
-            namespace=namespace,
-            async_req=True,
-        )
-        logger.info(f"Uploaded {len(batch_df)} vectors")
-        return response
-
-    [response.get() for response in map(_insert, batches)]
-
-    logger.info(f"Uploaded in {time.time() - start_time_upload} seconds")
-
 
 def get_namespace(request: Request, vault_id: str) -> str:
     return f"{request.scope.get('uid')}/{vault_id}"
 
 
 @app.get("/v1/{vault_id}/clear")
-def clear(
+async def clear(
     request: Request,
     vault_id: str,
     _: Settings = Depends(get_settings),
 ):
     namespace = get_namespace(request, vault_id)
 
-    index.delete(delete_all=True, namespace=namespace)
+    await vector_database.clear(namespace=namespace)
     logger.info("Cleared index")
     return JSONResponse(status_code=200, content={"status": "success"})
 
 
 @app.post("/v1/{vault_id}")
-def add(
+async def add(
     request: Request,
     vault_id: str,
     request_body: AddRequest,
@@ -280,9 +245,7 @@ def add(
         return JSONResponse(status_code=200, content={"status": "success"})
 
     # add column "hash" based on "data"
-    df.hash = df.data.apply(
-        lambda x: hashlib.sha256(x.encode()).hexdigest()
-    )
+    df.hash = df.data.apply(lambda x: hashlib.sha256(x.encode()).hexdigest())
 
     df_length = len(df)
 
     existing_hashes = []
@@ -295,20 +258,19 @@ def add(
         # in the index metadata
         ids_to_fetch = df.id.apply(urllib.parse.quote).tolist()
         # split in chunks of n because fetch has a limit of size
+        # TODO: abstract away batching
         n = 200
         ids_to_fetch = [ids_to_fetch[i : i + n] for i in range(0, len(ids_to_fetch), n)]
         logger.info(f"Fetching {len(ids_to_fetch)} chunks of {n} ids")
 
-        def _fetch(ids):
+        async def _fetch(ids) -> List[dict]:
             try:
-                return index.fetch(ids=ids, namespace=namespace)
+                return await vector_database.fetch(ids=ids, namespace=namespace)
             except Exception as e:
                 logger.error(f"Error fetching {ids}: {e}", exc_info=True)
                 raise e
 
-        with ThreadPool(len(ids_to_fetch)) as pool:
-            existing_documents = pool.map(lambda n: _fetch(n), ids_to_fetch)
-        # flatten vectors.values()
+        existing_documents = await asyncio.gather(*[_fetch(ids) for ids in ids_to_fetch])
         flat_existing_documents = itertools.chain.from_iterable(
             [doc.vectors.values() for doc in existing_documents]
         )
@@ -328,9 +290,7 @@ def _fetch(ids):
         ]
     else:
         # generate ids using hash + time
-        df.id = df.hash.apply(
-            lambda x: f"{x}-{int(time.time() * 1000)}"
-        )
+        df.id = df.hash.apply(lambda x: f"{x}-{int(time.time() * 1000)}")
 
     diff = df_length - len(df)
 
@@ -364,7 +324,7 @@ def _fetch(ids):
     #     # merge s column into a single column , ignore index
     #     df.embedding = s.apply(lambda x: x.tolist(), axis=1)
     # TODO: problem is that pinecone doesn't support this large of an input
-    upload_embeddings_to_vector_database(df, namespace)
+    await vector_database.update(df, namespace, batch_size=UPLOAD_BATCH_SIZE)
 
     logger.info(f"Indexed & uploaded {len(df)} sentences")
     end_time = time.time()
@@ -379,8 +339,9 @@ def _fetch(ids):
         },
     )
 
+
 @app.delete("/v1/{vault_id}")
-def delete(
+async def delete(
     request: Request,
     vault_id: str,
     request_body: DeleteRequest,
@@ -395,14 +356,14 @@ def delete(
     ids = request_body.ids
     logger.info(f"Deleting {len(ids)} documents")
     quoted_ids = [urllib.parse.quote(id) for id in ids]
-    index.delete(ids=quoted_ids, namespace=namespace)
+    await vector_database.delete(ids=quoted_ids, namespace=namespace)
 
     logger.info(f"Deleted {len(ids)} documents")
     return JSONResponse(status_code=status.HTTP_200_OK, content={"status": "success"})
 
 
 @app.post("/v1/{vault_id}/search")
-def semantic_search(
+async def semantic_search(
     request: Request,
     vault_id: str,
     request_body: SearchRequest,
@@ -422,16 +383,14 @@ def semantic_search(
 
     logger.info(f"Query {request_body.query} created embedding, querying index")
 
-    query_response = index.query(
+    query_response = await vector_database.search(
         top_k=top_k,
-        include_values=True,
-        include_metadata=True,
         vector=query_embedding,
         namespace=namespace,
     )
 
     similarities = []
-    for match in query_response.matches:
+    for match in query_response:
         logger.debug(f"Match id: {match.id}")
         decoded_id = urllib.parse.unquote(match.id)
         logger.debug(f"Decoded id: {decoded_id}")
diff --git a/search/db.py b/search/db.py
new file mode 100644
index 00000000..6228df12
--- /dev/null
+++ b/search/db.py
@@ -0,0 +1,60 @@
+from abc import ABC, abstractmethod
+from typing import Coroutine, List, Optional, Tuple
+
+from pandas import DataFrame
+
+
+# TODO: make this less Pinecone specific
+class VectorDatabase(ABC):
+    """
+    Base class for all vector databases
+    """
+
+    @abstractmethod
+    async def fetch(
+        self, ids: List[str], namespace: Optional[str] = None
+    ) -> List[dict]:
+        """
+        :param ids: list of ids
+        :param namespace: namespace
+        :return: list of vectors
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def update(
+        self,
+        df: DataFrame,
+        namespace: Optional[str] = None,
+        batch_size: Optional[int] = 100,
+    ) -> Coroutine:
+        """
+        :param df: dataframe of documents to upsert (id, data, embedding columns)
+        :param namespace: namespace
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def delete(self, ids: List[str], namespace: Optional[str] = None) -> None:
+        """
+        :param ids: list of ids
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def search(
+        self, vector: List[float], top_k: Optional[int], namespace: Optional[str] = None
+    ) -> List[dict]:
+        """
+        :param vector: vector
+        :param top_k: top k
+        :param namespace: namespace
+        :return: list of vectors
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def clear(self, namespace: Optional[str] = None) -> None:
+        """
+        :param namespace: namespace
+        """
+        raise NotImplementedError
diff --git a/search/pinecone_db.py b/search/pinecone_db.py
new file mode 100644
index 00000000..d62db566
--- /dev/null
+++ b/search/pinecone_db.py
@@ -0,0 +1,101 @@
+from typing import Coroutine, List, Optional
+from pandas import DataFrame
+from search.utils import BatchGenerator, too_big_rows
+from search.db import VectorDatabase
+import urllib.parse
+import pinecone
+
+
+class Pinecone(VectorDatabase):
+    def __init__(
+        self,
+        api_key: str,
+        environment: str,
+        index_name: str,
+        default_namespace: Optional[str] = None,
+    ):
+        """
+        :param api_key: pinecone api key
+        :param environment: pinecone environment
+        :param index_name: pinecone index name
+        :param default_namespace: optional namespace used when none is passed per call
+        """
+        pinecone.init(
+            api_key=api_key,
+            environment=environment,
+        )
+        self.index = pinecone.Index(index_name, pool_threads=8)
+        self.default_namespace = default_namespace
+
+    async def fetch(
+        self, ids: List[str], namespace: Optional[str] = None
+    ) -> List[dict]:
+        """
+        :param ids: list of ids
+        :return: list of vectors
+        """
+        return self.index.fetch(ids, namespace=namespace or self.default_namespace)
+
+    async def update(
+        self,
+        df: DataFrame,
+        namespace: Optional[str] = None,
+        batch_size: Optional[int] = 100,
+    ) -> Coroutine:
+        """
+        :param df: dataframe of documents to upsert (id, data, embedding columns)
+        :param namespace: namespace
+        """
+        df_batcher = BatchGenerator(batch_size)
+        batches = [batch_df for batch_df in df_batcher(df)]
+
+        def _insert(batch_df: DataFrame):
+            bigs = too_big_rows(batch_df)
+            # remove rows that are too big, in the right axis
+            batch_df = batch_df.drop(bigs, axis=0)
+            response = self.index.upsert(
+                vectors=zip(
+                    # pinecone needs to have the document path url encoded
+                    batch_df.id.apply(urllib.parse.quote).tolist(),
+                    batch_df.embedding,
+                    [
+                        {
+                            "data": data,
+                        }
+                        for data in batch_df.data
+                    ],
+                ),
+                namespace=namespace or self.default_namespace,
+                async_req=True,
+            )
+            return response
+
+        [response.get() for response in map(_insert, batches)]
+
+    async def delete(self, ids: List[str], namespace: Optional[str] = None) -> None:
+        """
+        :param ids: list of ids
+        """
+        self.index.delete(ids, namespace=namespace or self.default_namespace)
+
+    async def search(
+        self, vector: List[float], top_k: Optional[int], namespace: Optional[str] = None
+    ) -> List[dict]:
+        """
+        :param vector: vector
+        :param top_k: top k
+        :param namespace: namespace
+        :return: list of vectors
+        """
+        return self.index.query(
+            vector,
+            top_k=top_k,
+            namespace=namespace or self.default_namespace,
+            include_values=True,
+            include_metadata=True,
+        ).matches
+
+    async def clear(self, namespace: Optional[str] = None) -> None:
+        """
+        :param namespace: namespace
+        """
+        # HACK: pass a dummy id ("%") together with clear_all because the Pinecone
+        # documentation is incorrect and both ids and clear_all are needed here
+        self.index.delete(
+            ids=["%"], clear_all=True, namespace=namespace or self.default_namespace
+        )
diff --git a/search/test_main.py b/search/test_main.py
index 3cc1d9d1..f7a6bf40 100644
--- a/search/test_main.py
+++ b/search/test_main.py
@@ -1,26 +1,33 @@
 from fastapi.testclient import TestClient
+from httpx import AsyncClient
 from pandas import DataFrame
-from .api import app, embed, upload_embeddings_to_vector_database, index, no_batch_embed
+import pytest
+from search.pinecone_db import Pinecone
+
+from search.settings import get_settings
+from .api import app, embed, no_batch_embed
 import pandas as pd
 import math
 from random import randint
 import numpy as np
 
-def test_clear():
-    with TestClient(app=app) as client:
-        response = client.get(
-            "/v1/dev/clear",
+@pytest.mark.asyncio
+async def test_clear():
+    async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
+        response = await client.get(
+            "/v1/dev/clear",
         )
     assert response.status_code == 200
     assert response.json().get("status", "") == "success"
 
-def test_semantic_search():
-    with TestClient(app=app) as client:
-        response = client.post("/v1/dev/search", json={"query": "bob"})
+@pytest.mark.asyncio
+async def test_semantic_search():
+    async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
+        response = await client.post("/v1/dev/search", json={"query": "bob"})
 
     assert response.status_code == 200
 
-def test_refresh_small_documents():
+@pytest.mark.asyncio
+async def test_refresh_small_documents():
     df = pd.DataFrame(
         [
             "".join(
@@ -33,8 +40,8 @@ def test_refresh_small_documents():
         ],
         columns=["text"],
     )
-    with TestClient(app=app) as client:
-        response = client.post(
+    async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
+        response = await client.post(
             "/v1/dev",
             json={
                 "documents": [
@@ -58,7 +65,8 @@ def test_embed_large_text():
     data = no_batch_embed("".join("a" * 10_000))
     assert len(data) == 1536
 
-def test_upload():
+@pytest.mark.asyncio
+async def test_upload():
     data = embed(["hello world", "hello world"])
     df = DataFrame(
         [
@@ -75,18 +83,24 @@ def test_upload():
             "id",
         ],
     )
-    upload_embeddings_to_vector_database(df, "unit_test_test_upload")
-    results = index.query(
+    settings = get_settings()
+    vector_database = Pinecone(
+        api_key=settings.pinecone_api_key,
+        environment=settings.pinecone_environment,
+        index_name=settings.pinecone_index,
+    )
+    await vector_database.update(df, "unit_test_test_upload")
+
+    results = await vector_database.search(
         data[0]["embedding"],
         top_k=2,
-        include_values=True,
         namespace="unit_test_test_upload",
     )
-    assert results.matches[0]["id"] == "1"
-    assert results.matches[1]["id"] == "0"
-
+    assert results[0]["id"] == "1"
+    assert results[1]["id"] == "0"
 
-def test_ignore_document_that_didnt_change():
+@pytest.mark.asyncio
+async def test_ignore_document_that_didnt_change():
     df = pd.DataFrame(
         [
             ("".join(
@@ -99,11 +113,11 @@ def test_ignore_document_that_didnt_change():
         ],
         columns=["text", "id"],
     )
-    with TestClient(app=app) as client:
-        response = client.get(
+    async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
+        response = await client.get(
             "/v1/dev/clear",
         )
-        response = client.post(
+        response = await client.post(
             "/v1/dev",
             json={
                 "documents": [
@@ -119,8 +133,8 @@ def test_ignore_document_that_didnt_change():
     ids = response.json().get("inserted_ids", [])
     # add to df
     df["id"] = ids
-    with TestClient(app=app) as client:
-        response = client.post(
+    async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
+        response = await client.post(
             "/v1/dev",
             json={
                 "documents": [
diff --git a/search/weaviate_db.py b/search/weaviate_db.py
new file mode 100644
index 00000000..1e0f2366
--- /dev/null
+++ b/search/weaviate_db.py
@@ -0,0 +1,55 @@
+from typing import Coroutine, List, Optional
+from pandas import DataFrame
+from search.db import VectorDatabase
+
+
+class Weaviate(VectorDatabase):
+    def __init__(
+        self,
+    ):
+        """
+        Placeholder Weaviate backend, not implemented yet.
+        """
+        pass
+
+    async def fetch(
+        self, ids: List[str], namespace: Optional[str] = None
+    ) -> List[dict]:
+        """
+        :param ids: list of ids
+        :return: list of vectors
+        """
+        raise NotImplementedError
+
+    async def update(
+        self,
+        df: DataFrame,
+        namespace: Optional[str] = None,
+        batch_size: Optional[int] = 100,
+    ) -> Coroutine:
+        """
+        :param df: dataframe of documents to upsert (id, data, embedding columns)
+        :param namespace: namespace
+        """
+        raise NotImplementedError
+
+    async def delete(self, ids: List[str], namespace: Optional[str] = None) -> None:
+        """
+        :param ids: list of ids
+        """
+        raise NotImplementedError
+
+    async def search(
+        self, vector: List[float], top_k: Optional[int], namespace: Optional[str] = None
+    ) -> List[dict]:
+        """
+        :param vector: vector
+        :param top_k: top k
+        :param namespace: namespace
+        :return: list of vectors
+        """
+        raise NotImplementedError
+
+    async def clear(self, namespace: Optional[str] = None) -> None:
+        """
+        :param namespace: namespace
+        """
+        raise NotImplementedError
diff --git a/setup.py b/setup.py
index bf78341d..df1a71ae 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@
     name="embedbase",
     packages=find_packages(),
     include_package_data=True,
-    version="0.5.8",
+    version="0.5.9",
     description="",
     install_requires=install_requires,
     extras_require=extras_require,
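
Usage sketch: the snippet below shows roughly how the VectorDatabase abstraction introduced in this patch is consumed, mirroring the calls made in search/api.py and search/test_main.py above. It is a minimal sketch, assuming the settings fields and the (id, data, embedding) DataFrame columns used in the diff; the namespace string, document text and constant-valued vectors are illustrative placeholders, and running it requires real Pinecone credentials in the settings.

# Minimal usage sketch of the VectorDatabase abstraction (illustrative, not part of the patch).
import asyncio

from pandas import DataFrame

from search.pinecone_db import Pinecone
from search.settings import get_settings


async def main() -> None:
    settings = get_settings()
    db = Pinecone(
        api_key=settings.pinecone_api_key,
        environment=settings.pinecone_environment,
        index_name=settings.pinecone_index,
    )

    # Normally built by get_namespace() in api.py as "<uid>/<vault_id>"; hardcoded here.
    namespace = "some-uid/some-vault"

    # Upsert one document; in api.py the 1536-dim embedding comes from the embed() helper.
    df = DataFrame(
        [["doc-0", "hello world", [0.1] * 1536]],
        columns=["id", "data", "embedding"],
    )
    await db.update(df, namespace, batch_size=100)

    # Query the namespace, then clean it up.
    matches = await db.search(vector=[0.1] * 1536, top_k=5, namespace=namespace)
    print([match.id for match in matches])
    await db.clear(namespace=namespace)


if __name__ == "__main__":
    asyncio.run(main())

The "namespace or self.default_namespace" fallback in the Pinecone backend means callers can either pass a namespace per call, as api.py does, or bind a default one at construction time.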