0.5.9: feat: abstract away database
louis030195 committed Feb 13, 2023
1 parent bd724ec commit a597eec
Showing 8 changed files with 285 additions and 95 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,6 +1,8 @@
# python garbage
__pycache__
.pytest_cache
*.egg-info
build

# python virtual environment
env
5 changes: 2 additions & 3 deletions Makefile
@@ -6,9 +6,8 @@ LOCAL_PORT="8000"

install: ## [DEVELOPMENT] Install the API dependencies
virtualenv env; \
source env/bin/activate; \
pip install -r requirements.txt; \
pip install -r requirements-test.txt
. env/bin/activate; \
pip install .[all]
@echo "Done, run '\033[0;31msource env/bin/activate\033[0m' to activate the virtual environment"

run: ## [DEVELOPMENT] Run the API
91 changes: 25 additions & 66 deletions search/api.py
@@ -1,3 +1,4 @@
import asyncio
import hashlib
from multiprocessing.pool import ThreadPool
import time
@@ -15,11 +16,10 @@
SearchRequest,
)
from fastapi.responses import JSONResponse
import pinecone
import urllib.parse
import numpy as np
from search.pinecone_db import Pinecone
from search.settings import Settings, get_settings
from .utils import BatchGenerator, too_big_rows
import openai
import sentry_sdk

Expand All @@ -29,7 +29,7 @@
from tenacity.after import after_log
from tenacity.stop import stop_after_attempt
import requests
from typing import Tuple
from typing import List, Tuple


settings = get_settings()
@@ -129,18 +129,18 @@ async def firebase_auth(request: Request, call_next) -> Tuple[str, str]:
allow_headers=["*"],
)

pinecone.init(
api_key=settings.pinecone_api_key, environment=settings.pinecone_environment
vector_database = Pinecone(
api_key=settings.pinecone_api_key,
environment=settings.pinecone_environment,
index_name=settings.pinecone_index,
)
openai.api_key = settings.openai_api_key
openai.organization = settings.openai_organization
pinecone_index = settings.pinecone_index
index = pinecone.Index(pinecone_index, pool_threads=8)


@app.on_event("startup")
def startup_event():
result = index.fetch(ids=["foo"]) # TODO: container startup check
async def startup_event():
result = await vector_database.fetch(ids=["foo"]) # TODO: container startup check
if result:
logger.info("Properly connected to Pinecone")
else:
@@ -188,61 +188,26 @@ def embed(
after=after_log(logger, logging.ERROR),
stop=stop_after_attempt(3),
)
def upload_embeddings_to_vector_database(df: DataFrame, namespace: str):
# TODO: batch size should depend on payload
df_batcher = BatchGenerator(UPLOAD_BATCH_SIZE)
logger.info("Uploading vectors namespace..")
start_time_upload = time.time()
batches = [batch_df for batch_df in df_batcher(df)]

def _insert(batch_df: DataFrame):
bigs = too_big_rows(batch_df)
if len(bigs) > 0:
logger.info(f"Ignoring {len(bigs)} rows that are too big")
# remove rows that are too big, in the right axis
batch_df = batch_df.drop(bigs, axis=0)
response = index.upsert(
vectors=zip(
# pinecone needs to have the document path url encoded
batch_df.id.apply(urllib.parse.quote).tolist(),
batch_df.embedding,
[
{
"data": data,
}
for data in batch_df.data
],
),
namespace=namespace,
async_req=True,
)
logger.info(f"Uploaded {len(batch_df)} vectors")
return response

[response.get() for response in map(_insert, batches)]

logger.info(f"Uploaded in {time.time() - start_time_upload} seconds")


def get_namespace(request: Request, vault_id: str) -> str:
return f"{request.scope.get('uid')}/{vault_id}"


@app.get("/v1/{vault_id}/clear")
def clear(
async def clear(
request: Request,
vault_id: str,
_: Settings = Depends(get_settings),
):
namespace = get_namespace(request, vault_id)

index.delete(delete_all=True, namespace=namespace)
await vector_database.clear(namespace=namespace)
logger.info("Cleared index")
return JSONResponse(status_code=200, content={"status": "success"})


@app.post("/v1/{vault_id}")
def add(
async def add(
request: Request,
vault_id: str,
request_body: AddRequest,
@@ -280,9 +245,7 @@ def add(
return JSONResponse(status_code=200, content={"status": "success"})

# add column "hash" based on "data"
df.hash = df.data.apply(
lambda x: hashlib.sha256(x.encode()).hexdigest()
)
df.hash = df.data.apply(lambda x: hashlib.sha256(x.encode()).hexdigest())

df_length = len(df)
existing_hashes = []
@@ -295,20 +258,19 @@ def add(
# in the index metadata
ids_to_fetch = df.id.apply(urllib.parse.quote).tolist()
# split in chunks of n because fetch has a limit of size
# TODO: abstract away batching
n = 200
ids_to_fetch = [ids_to_fetch[i : i + n] for i in range(0, len(ids_to_fetch), n)]
logger.info(f"Fetching {len(ids_to_fetch)} chunks of {n} ids")

def _fetch(ids):
async def _fetch(ids) -> List[dict]:
try:
return index.fetch(ids=ids, namespace=namespace)
return await vector_database.fetch(ids=ids, namespace=namespace)
except Exception as e:
logger.error(f"Error fetching {ids}: {e}", exc_info=True)
raise e

with ThreadPool(len(ids_to_fetch)) as pool:
existing_documents = pool.map(lambda n: _fetch(n), ids_to_fetch)
# flatten vectors.values()
existing_documents = await asyncio.gather(*[_fetch(ids) for ids in ids_to_fetch])
flat_existing_documents = itertools.chain.from_iterable(
[doc.vectors.values() for doc in existing_documents]
)
@@ -328,9 +290,7 @@ def _fetch(ids):
]
else:
# generate ids using hash + time
df.id = df.hash.apply(
lambda x: f"{x}-{int(time.time() * 1000)}"
)
df.id = df.hash.apply(lambda x: f"{x}-{int(time.time() * 1000)}")

diff = df_length - len(df)

@@ -364,7 +324,7 @@ def _fetch(ids):
# # merge s column into a single column , ignore index
# df.embedding = s.apply(lambda x: x.tolist(), axis=1)
# TODO: problem is that pinecone doesn't support this large of an input
upload_embeddings_to_vector_database(df, namespace)
await vector_database.update(df, namespace, batch_size=UPLOAD_BATCH_SIZE)

logger.info(f"Indexed & uploaded {len(df)} sentences")
end_time = time.time()
@@ -379,8 +339,9 @@ def _fetch(ids):
},
)


@app.delete("/v1/{vault_id}")
def delete(
async def delete(
request: Request,
vault_id: str,
request_body: DeleteRequest,
@@ -395,14 +356,14 @@ def delete(
ids = request_body.ids
logger.info(f"Deleting {len(ids)} documents")
quoted_ids = [urllib.parse.quote(id) for id in ids]
index.delete(ids=quoted_ids, namespace=namespace)
await vector_database.delete(ids=quoted_ids, namespace=namespace)
logger.info(f"Deleted {len(ids)} documents")

return JSONResponse(status_code=status.HTTP_200_OK, content={"status": "success"})


@app.post("/v1/{vault_id}/search")
def semantic_search(
async def semantic_search(
request: Request,
vault_id: str,
request_body: SearchRequest,
@@ -422,16 +383,14 @@ def semantic_search(

logger.info(f"Query {request_body.query} created embedding, querying index")

query_response = index.query(
query_response = await vector_database.search(
top_k=top_k,
include_values=True,
include_metadata=True,
vector=query_embedding,
namespace=namespace,
)

similarities = []
for match in query_response.matches:
for match in query_response:
logger.debug(f"Match id: {match.id}")
decoded_id = urllib.parse.unquote(match.id)
logger.debug(f"Decoded id: {decoded_id}")
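
For reference, the add endpoint above now URL-encodes the document ids, splits them into chunks of 200, and fans the fetches out with asyncio.gather instead of a ThreadPool. Below is a minimal standalone sketch of that pattern, assuming only an object implementing the async VectorDatabase interface introduced in search/db.py further down; the helper name fetch_existing_documents and the flattening step are illustrative, not part of the commit.

import asyncio
import itertools
import urllib.parse
from typing import List


async def fetch_existing_documents(
    vector_database, ids: List[str], namespace: str, chunk_size: int = 200
) -> List[dict]:
    # the backend's fetch has a size limit, so URL-encode the ids and split them into chunks
    quoted_ids = [urllib.parse.quote(doc_id) for doc_id in ids]
    chunks = [quoted_ids[i : i + chunk_size] for i in range(0, len(quoted_ids), chunk_size)]
    # run all chunk fetches concurrently on the event loop instead of blocking a thread pool
    results = await asyncio.gather(
        *[vector_database.fetch(ids=chunk, namespace=namespace) for chunk in chunks]
    )
    # each result is the list of documents found for one chunk; flatten into a single list
    return list(itertools.chain.from_iterable(results))

Unlike the earlier ThreadPool version, no pool sizing is needed; the event loop multiplexes all chunk fetches.
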
60 changes: 60 additions & 0 deletions search/db.py
@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from typing import Coroutine, List, Optional, Tuple

from pandas import DataFrame

# TODO: make this less Pinecone specific
class VectorDatabase(ABC):
"""
Base class for all vector databases
"""

@abstractmethod
async def fetch(
self, ids: List[str], namespace: Optional[str] = None
) -> List[dict]:
"""
:param ids: list of ids
:param namespace: namespace
:return: list of vectors
"""
raise NotImplementedError

@abstractmethod
async def update(
self,
df: DataFrame,
namespace: Optional[str] = None,
batch_size: Optional[int] = 100,
) -> Coroutine:
"""
        :param df: dataframe containing the vectors to upsert
        :param namespace: namespace
        :param batch_size: number of rows to upsert per batch
"""
raise NotImplementedError

@abstractmethod
async def delete(self, ids: List[str], namespace: Optional[str] = None) -> None:
"""
:param ids: list of ids
"""
raise NotImplementedError

@abstractmethod
async def search(
self, vector: List[float], top_k: Optional[int], namespace: Optional[str] = None
) -> List[dict]:
"""
:param vector: vector
:param top_k: top k
:param namespace: namespace
:return: list of vectors
"""
raise NotImplementedError

@abstractmethod
async def clear(self, namespace: Optional[str] = None) -> None:
"""
:param namespace: namespace
"""
raise NotImplementedError
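
With the interface above, any backend only needs to subclass VectorDatabase and implement these five coroutines; api.py no longer imports pinecone directly. As a purely illustrative sketch (not part of this commit, and unrelated to the Pinecone class imported from search.pinecone_db), a toy in-memory backend along these lines can be handy for unit tests; the dictionary storage and cosine scoring here are assumptions.

from typing import List, Optional
import urllib.parse

import numpy as np
from pandas import DataFrame

from search.db import VectorDatabase


class InMemoryVectorDatabase(VectorDatabase):
    """Toy in-memory backend for tests; not part of this commit."""

    def __init__(self):
        # namespace -> {id: {"id": ..., "values": [...], "metadata": {...}}}
        self._store = {}

    async def fetch(self, ids: List[str], namespace: Optional[str] = None) -> List[dict]:
        documents = self._store.get(namespace, {})
        return [documents[doc_id] for doc_id in ids if doc_id in documents]

    async def update(
        self,
        df: DataFrame,
        namespace: Optional[str] = None,
        batch_size: Optional[int] = 100,
    ) -> None:
        documents = self._store.setdefault(namespace, {})
        batch_size = batch_size or 100
        for start in range(0, len(df), batch_size):
            for _, row in df.iloc[start : start + batch_size].iterrows():
                doc_id = urllib.parse.quote(row["id"])
                documents[doc_id] = {
                    "id": doc_id,
                    "values": list(row["embedding"]),
                    "metadata": {"data": row["data"]},
                }

    async def delete(self, ids: List[str], namespace: Optional[str] = None) -> None:
        documents = self._store.get(namespace, {})
        for doc_id in ids:
            documents.pop(doc_id, None)

    async def search(
        self, vector: List[float], top_k: Optional[int], namespace: Optional[str] = None
    ) -> List[dict]:
        documents = self._store.get(namespace, {})
        query = np.array(vector)
        scored = []
        for doc in documents.values():
            values = np.array(doc["values"])
            # cosine similarity between the query and the stored embedding
            score = float(
                query @ values / (np.linalg.norm(query) * np.linalg.norm(values) + 1e-8)
            )
            scored.append({**doc, "score": score})
        scored.sort(key=lambda d: d["score"], reverse=True)
        return scored[: top_k or 10]

    async def clear(self, namespace: Optional[str] = None) -> None:
        self._store.pop(namespace, None)

The same contract makes it straightforward to swap in another store later without touching the FastAPI endpoints, which is the point of abstracting the database away in this release.
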