Skip to content

Commit

Permalink
0.6.2: fix: id generation collision
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Feb 13, 2023
1 parent 850c5f8 commit 415008a
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 10 deletions.
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ LOCAL_PORT="8000"
install: ## [DEVELOPMENT] Install the API dependencies
virtualenv env; \
. env/bin/activate; \
pip install .[all]
pip install .[all]; \
pip install -r requirements-test.txt
@echo "Done, run '\033[0;31msource env/bin/activate\033[0m' to activate the virtual environment"

run: ## [DEVELOPMENT] Run the API
Expand All @@ -18,6 +19,7 @@ test: ## [Local development] Run tests with pytest.
python3 -m pytest -s test_main.py::test_clear; \
python3 -m pytest -s test_main.py::test_semantic_search; \
python3 -m pytest -s test_main.py::test_refresh_small_documents; \
python3 -m pytest -s test_main.py::test_sync_no_id_collision; \
python3 -m pytest -s test_main.py::test_embed; \
python3 -m pytest -s test_main.py::test_embed_large_text; \
python3 -m pytest -s test_main.py::test_upload; \
Expand Down
19 changes: 12 additions & 7 deletions embedbase/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from tenacity.stop import stop_after_attempt
import requests
from typing import List, Tuple

import uuid

settings = get_settings()
MAX_DOCUMENT_LENGTH = int(os.environ.get("MAX_DOCUMENT_LENGTH", "1000"))
Expand All @@ -50,6 +50,7 @@
if settings.sentry:
logger.info("Enabling Sentry")
import sentry_sdk

sentry_sdk.init(
dsn=settings.sentry,
# Set traces_sample_rate to 1.0 to capture 100%
Expand Down Expand Up @@ -188,7 +189,6 @@ def embed(
after=after_log(logger, logging.ERROR),
stop=stop_after_attempt(3),
)

def get_namespace(request: Request, vault_id: str) -> str:
return f"{request.scope.get('uid')}/{vault_id}"

Expand Down Expand Up @@ -268,13 +268,13 @@ async def _fetch(ids) -> List[dict]:
logger.error(f"Error fetching {ids}: {e}", exc_info=True)
raise e

existing_documents = await asyncio.gather(*[_fetch(ids) for ids in ids_to_fetch])
existing_documents = await asyncio.gather(
*[_fetch(ids) for ids in ids_to_fetch]
)
flat_existing_documents = itertools.chain.from_iterable(
[doc.vectors.values() for doc in existing_documents]
)

# TODO: might do also with https://docs.pinecone.io/docs/metadata-filtering#querying-an-index-with-metadata-filters

# remove rows that have the same hash
exisiting_contents = []
for doc in flat_existing_documents:
Expand All @@ -287,8 +287,13 @@ async def _fetch(ids) -> List[dict]:
)
]
else:
# generate ids using hash + time
df.id = df.hash.apply(lambda x: f"{x}-{int(time.time() * 1000)}")
# generate ids using hash of uuid + time to avoid collisions
df.id = df.apply(
lambda x: hashlib.sha256(
(str(uuid.uuid4()) + str(time.time())).encode()
).hexdigest(),
axis=1,
)

diff = df_length - len(df)

Expand Down
33 changes: 32 additions & 1 deletion embedbase/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,38 @@ async def test_refresh_small_documents():
},
)
assert response.status_code == 200
assert response.json().get("status", "") == "success"
json_response = response.json()
assert json_response.get("status", "") == "success"
assert len(json_response.get("inserted_ids")) == 10


@pytest.mark.asyncio
async def test_sync_no_id_collision():
df = pd.DataFrame(
[
"foo"
for _ in range(10)
],
columns=["text"],
)
async with AsyncClient(app=app, base_url="http://localhost:8000") as client:
response = await client.post(
"/v1/dev",
json={
"documents": [
{
"data": text,
}
for i, text in enumerate(df.text.tolist())
],
},
)
assert response.status_code == 200
json_response = response.json()
assert json_response.get("status", "") == "success"
# make sure all ids are unique
assert len(set(json_response.get("inserted_ids"))) == 10



def test_embed():
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
name="embedbase",
packages=find_packages(),
include_package_data=True,
version="0.6.1",
version="0.6.2",
description="Open-source API for to easily create, store, and retrieve embeddings",
install_requires=install_requires,
extras_require=extras_require,
Expand Down

0 comments on commit 415008a

Please sign in to comment.