Commit
0.6.8: feat: prevent clear data retention; fix: clearing index typo
louis030195 committed Feb 17, 2023
1 parent efb99fc commit cb09d85
Showing 7 changed files with 108 additions and 11 deletions.
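
The feature half of this commit introduces a `save_clear_data` setting: when it is turned off, Embedbase stops persisting the raw ("clear") document text alongside its embedding, so searches still return scores and ids but a null `data` field. A minimal sketch of opting out, assuming the YAML file that `Settings` (a `YamlModel`) is loaded from is named `config.yaml` (the file name and the other keys are assumptions, not part of this diff):

# config.yaml (hypothetical example)
# only save_clear_data is new in this commit; the other keys mirror existing Settings fields
middlewares: null
sentry: null
save_clear_data: false  # store embeddings only; do not retain the raw document text
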
3 changes: 2 additions & 1 deletion Makefile
@@ -23,7 +23,8 @@ test: ## [Local development] Run tests with pytest.
 	python3 -m pytest -s test_main.py::test_embed; \
 	python3 -m pytest -s test_main.py::test_embed_large_text; \
 	python3 -m pytest -s test_main.py::test_upload; \
-	python3 -m pytest -s test_main.py::test_ignore_document_that_didnt_change
+	python3 -m pytest -s test_main.py::test_ignore_document_that_didnt_change; \
+	python3 -m pytest -s test_main.py::test_save_clear_data
 	@echo "Done testing"
 
 docker/build/prod: ## [Local development] Build the docker image.
15 changes: 9 additions & 6 deletions embedbase/api.py
@@ -47,6 +47,7 @@
 middlewares = []
 if settings.middlewares:
     from starlette.middleware import Middleware
+
     for i, m in enumerate(settings.middlewares):
         # import python file at path m
        # and add the first class found to the list
@@ -134,8 +135,6 @@ async def firebase_auth(request: Request, call_next) -> Tuple[str, str]:
     return response
 
 
-
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -200,6 +199,7 @@ def embed(
 def get_namespace(request: Request, vault_id: str) -> str:
     return f"{request.scope.get('uid')}/{vault_id}"
 
+
 @app.get("/v1/{vault_id}/clear")
 async def clear(
     request: Request,
@@ -283,10 +283,8 @@ async def _fetch(ids) -> List[dict]:
     )
 
     # remove rows that have the same hash
-    exisiting_contents = []
     for doc in flat_existing_documents:
         existing_hashes.append(doc.id)
-        exisiting_contents.append(doc.get("metadata", {}).get("data"))
     df = df[
         ~df.apply(
             lambda x: x.hash in existing_hashes,
@@ -334,7 +332,12 @@ async def _fetch(ids) -> List[dict]:
     # # merge s column into a single column , ignore index
     # df.embedding = s.apply(lambda x: x.tolist(), axis=1)
     # TODO: problem is that pinecone doesn't support this large of an input
-    await vector_database.update(df, namespace, batch_size=UPLOAD_BATCH_SIZE)
+    await vector_database.update(
+        df,
+        namespace,
+        batch_size=UPLOAD_BATCH_SIZE,
+        save_clear_data=settings.save_clear_data,
+    )
 
     logger.info(f"Indexed & uploaded {len(df)} sentences")
     end_time = time.time()
@@ -406,7 +409,7 @@ async def semantic_search(
             {
                 "score": match.score,
                 "id": decoded_id,
-                "data": match.metadata.get("data", None),
+                "data": match.get("metadata", {}).get("data", None),
             }
         )
     return JSONResponse(
3 changes: 3 additions & 0 deletions embedbase/db.py
@@ -26,10 +26,13 @@ async def update(
         df: DataFrame,
         namespace: Optional[str] = None,
         batch_size: Optional[int] = 100,
+        save_clear_data: bool = True,
     ) -> Coroutine:
         """
         :param vectors: list of vectors
         :param namespace: namespace
         :param batch_size: batch size
+        :param save_clear_data: save clear data
         """
         raise NotImplementedError
 
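
The abstract `update` above only grows the flag; each backend decides what skipping clear data means when it persists vectors. As a generic illustration, a hypothetical in-memory backend (not part of this commit) could honor the flag by dropping the raw text from the metadata it stores:

from typing import Optional

from pandas import DataFrame


class MemoryDatabase:
    # hypothetical in-memory backend, for illustration only
    def __init__(self):
        self.storage = {}

    async def update(
        self,
        df: DataFrame,
        namespace: Optional[str] = None,
        batch_size: Optional[int] = 100,  # unused in this sketch
        save_clear_data: bool = True,
    ) -> None:
        for row in df.itertuples():
            # keep the raw text only when save_clear_data is enabled
            metadata = {"data": row.data} if save_clear_data else {}
            self.storage[(namespace, row.id)] = (row.embedding, metadata)
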
8 changes: 7 additions & 1 deletion embedbase/pinecone_db.py
@@ -39,10 +39,13 @@ async def update(
         df: DataFrame,
         namespace: Optional[str] = None,
         batch_size: Optional[int] = 100,
+        save_clear_data: bool = True,
     ) -> Coroutine:
         """
         :param vectors: list of vectors
         :param namespace: namespace
         :param batch_size: batch size
+        :param save_clear_data: save clear data
         """
         df_batcher = BatchGenerator(batch_size)
         batches = [batch_df for batch_df in df_batcher(df)]
@@ -62,6 +65,9 @@ def _insert(batch_df: DataFrame):
                     }
                     for data in batch_df.data
                 ],
+            ) if save_clear_data else zip(
+                batch_df.id.apply(urllib.parse.quote).tolist(),
+                batch_df.embedding,
             ),
             namespace=namespace or self.default_namespace,
             async_req=True,
@@ -98,4 +104,4 @@ async def clear(self, namespace: Optional[str]) -> None:
         :param namespace: namespace
         """
         # HACK: stupid hack "%" because pinecone doc is incorrect and u need to pass ids & clear_all
-        self.index.delete(ids=["%"], clear_all=True, namespace=namespace or self.default_namespace)
+        self.index.delete(delete_all=True, namespace=namespace or self.default_namespace)
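
The "clearing index typo" from the commit title is this last hunk: the previous call passed ids=["%"] with a clear_all keyword, while the pinecone client spells the option delete_all. For reference, a minimal sketch of the corrected call in isolation, assuming the v2 pinecone-client of the time (index name, credentials, and namespace are placeholders):

import pinecone

pinecone.init(api_key="...", environment="...")  # placeholder credentials
index = pinecone.Index("my-index")  # placeholder index name

# delete_all wipes every vector in the namespace in a single call,
# which is what the old ids=["%"] workaround was approximating
index.delete(delete_all=True, namespace="some-uid/dev")
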
1 change: 1 addition & 0 deletions embedbase/settings.py
@@ -25,6 +25,7 @@ class Settings(YamlModel):
     sentry: typing.Optional[str] = None
     firebase_service_account_path: typing.Optional[str] = None
     middlewares: typing.Optional[typing.List[str]] = None
+    save_clear_data: bool = True
 
 @lru_cache()
 def get_settings():
87 changes: 85 additions & 2 deletions embedbase/test_main.py
@@ -5,26 +5,65 @@
 from embedbase.pinecone_db import Pinecone
 
 from embedbase.settings import get_settings
-from .api import app, embed, no_batch_embed
+from .api import app, embed, no_batch_embed, settings
 import pandas as pd
 import math
 from random import randint
 import numpy as np
 
 @pytest.mark.asyncio
 async def test_clear():
     df = pd.DataFrame(
         [
             "".join(
                 [
                     chr(math.floor(97 + 26 * np.random.rand()))
                     for _ in range(randint(500, 800))
                 ]
             )
             for _ in range(10)
         ],
         columns=["text"],
     )
     async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
         response = await client.post(
             "/v1/dev",
             json={
                 "documents": [
                     {
                         "data": text,
                     }
                     for i, text in enumerate(df.text.tolist())
                 ],
             },
         )
         assert response.status_code == 200
         json_response = response.json()
         assert json_response.get("status", "") == "success"
         assert len(json_response.get("inserted_ids")) == 10
 
     async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
         response = await client.get(
             "/v1/dev/clear",
         )
         assert response.status_code == 200
         assert response.json().get("status", "") == "success"
 
     # search now
     async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
         response = await client.post("/v1/dev/search", json={"query": "bob"})
         assert response.status_code == 200
         json_response = response.json()
         assert json_response.get("query", "") == "bob"
         assert len(json_response.get("similarities")) == 0
 
 @pytest.mark.asyncio
 async def test_semantic_search():
     async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
         response = await client.post("/v1/dev/search", json={"query": "bob"})
         assert response.status_code == 200
         json_response = response.json()
         assert json_response.get("query", "") == "bob"
 
 @pytest.mark.asyncio
 async def test_refresh_small_documents():
@@ -178,4 +217,48 @@ async def test_ignore_document_that_didnt_change():
             },
         )
         assert response.status_code == 200
-        assert len(response.json().get("ignored_ids")) == 10
+        assert len(response.json().get("ignored_ids")) == 10
+
+
+@pytest.mark.asyncio
+async def test_save_clear_data():
+    # clear all
+    async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
+        response = await client.get(
+            "/v1/dev/clear",
+        )
+        assert response.status_code == 200
+        assert response.json().get("status", "") == "success"
+    df = pd.DataFrame(
+        [
+            "bob is a human"
+        ],
+        columns=["text"],
+    )
+    settings.save_clear_data = False
+    async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
+        response = await client.post(
+            "/v1/dev",
+            json={
+                "documents": [
+                    {
+                        "data": text,
+                    }
+                    for i, text in enumerate(df.text.tolist())
+                ],
+            },
+        )
+        assert response.status_code == 200
+        json_response = response.json()
+        assert json_response.get("status", "") == "success"
+        assert len(json_response.get("inserted_ids")) == 1
+    # now search shouldn't have the "data" field in the response
+    async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
+        response = await client.post(
+            "/v1/dev/search",
+            json={"query": "bob"},
+        )
+        assert response.status_code == 200
+        json_response = response.json()
+        assert len(json_response.get("similarities")) > 0
+        assert json_response.get("similarities")[0].get("data") is None
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@
     name="embedbase",
     packages=find_packages(),
     include_package_data=True,
-    version="0.6.7",
+    version="0.6.8",
     description="Open-source API for to easily create, store, and retrieve embeddings",
     install_requires=install_requires,
     extras_require=extras_require,
