Add type checking of ragstack-llama-index
cbornet committed Jul 25, 2024
1 parent e78d791 commit e9fe4f2
Showing 6 changed files with 68 additions and 37 deletions.
20 changes: 20 additions & 0 deletions libs/llamaindex/pyproject.toml
@@ -44,6 +44,26 @@ google = ["llama-index-llms-gemini", "llama-index-multi-modal-llms-gemini", "lla
 azure = ["llama-index-llms-azure-openai", "llama-index-embeddings-azure-openai"]
 bedrock = ["llama-index-llms-bedrock", "llama-index-embeddings-bedrock"]

+[tool.poetry.group.dev.dependencies]
+mypy = "^1.11.0"
+
 [tool.poetry.group.test.dependencies]
 ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
 ragstack-ai-colbert = { path = "../colbert", develop = true }
+
+[tool.mypy]
+disallow_any_generics = true
+disallow_incomplete_defs = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+follow_imports = "normal"
+ignore_missing_imports = true
+no_implicit_reexport = true
+show_error_codes = true
+show_error_context = true
+strict_equality = true
+strict_optional = true
+warn_redundant_casts = true
+warn_return_any = true
+warn_unused_ignores = true
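Taken together, the [tool.mypy] settings above approximate mypy's strict profile: every function must be fully annotated, bare generics are rejected, and redundant casts or stale type-ignore comments become errors. That configuration explains most of the mechanical edits in the files below. A minimal sketch of what it accepts and rejects (hypothetical code, not from this commit):

from typing import Dict


def chunk_count(index: Dict[str, int]) -> int:
    # Fully annotated, parameterized generic: passes the config above.
    return len(index)


def bad_chunk_count(index: dict):
    # Rejected: disallow_incomplete_defs flags the missing return
    # annotation, and disallow_any_generics flags the bare `dict`.
    return len(index)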
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 from llama_index.core.callbacks.base import CallbackManager
 from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
@@ -27,7 +27,7 @@ def __init__(
         retriever: ColbertBaseRetriever,
         similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
         callback_manager: Optional[CallbackManager] = None,
-        object_map: Optional[dict] = None,
+        object_map: Optional[Dict[str, Any]] = None,
         verbose: bool = False,
         query_maxlen: int = -1,
     ) -> None:
@@ -51,6 +51,6 @@ def _retrieve(
             query_maxlen=self._query_maxlen,
         )
         return [
-            NodeWithScore(node=TextNode(text=c.text, metadata=c.metadata), score=s)
+            NodeWithScore(node=TextNode(text=c.text, extra_info=c.metadata), score=s)
             for (c, s) in chunk_scores
         ]
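Both edits in this ColBERT retriever follow from the configuration above: disallow_any_generics forbids the bare dict annotation on object_map, and the TextNode call switches from metadata= to extra_info=, presumably because the typed constructor exposes the extra_info alias rather than the metadata field, so the old keyword no longer type-checks. A small sketch of the generics rule (hypothetical function names):

from typing import Any, Dict, Optional


def with_bare_generic(object_map: Optional[dict] = None) -> None:
    ...  # rejected by disallow_any_generics: `dict` has no type parameters


def with_parameters(object_map: Optional[Dict[str, Any]] = None) -> None:
    ...  # accepted: key and value types are spelled out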
3 changes: 2 additions & 1 deletion libs/llamaindex/tests/integration_tests/conftest.py
@@ -1,4 +1,5 @@
 import pytest
+from _pytest.fixtures import FixtureRequest
 from cassandra.cluster import Session
 from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore

@@ -17,7 +18,7 @@ def astra_db() -> AstraDBTestStore:


 @pytest.fixture()
-def session(request) -> Session:
+def session(request: FixtureRequest) -> Session:
     test_store = request.getfixturevalue(request.param)
     session = test_store.create_cassandra_session()
     session.default_timeout = 180
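Annotating the request parameter as FixtureRequest gives mypy concrete types for request.param and request.getfixturevalue. A self-contained sketch of the same pattern (hypothetical fixture and test):

import pytest
from _pytest.fixtures import FixtureRequest


@pytest.fixture()
def store_name(request: FixtureRequest) -> str:
    # Under indirect parametrization, request.param carries the value
    # that @pytest.mark.parametrize supplied for this fixture.
    return str(request.param)


@pytest.mark.parametrize("store_name", ["astra_db"], indirect=["store_name"])
def test_store_name(store_name: str) -> None:
    assert store_name == "astra_db"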
43 changes: 23 additions & 20 deletions libs/llamaindex/tests/integration_tests/test_colbert.py
@@ -5,8 +5,9 @@
 from cassandra.cluster import Session
 from llama_index.core import Settings, get_response_synthesizer
 from llama_index.core.ingestion import IngestionPipeline
+from llama_index.core.llms import MockLLM
 from llama_index.core.query_engine import RetrieverQueryEngine
-from llama_index.core.schema import Document, NodeWithScore
+from llama_index.core.schema import Document, NodeWithScore, QueryBundle
 from llama_index.core.text_splitter import SentenceSplitter
 from ragstack_colbert import (
     CassandraDatabase,
@@ -20,7 +21,7 @@
 logging.getLogger("cassandra").setLevel(logging.ERROR)


-def validate_retrieval(results: List[NodeWithScore], key_value: str):
+def validate_retrieval(results: List[NodeWithScore], key_value: str) -> bool:
     passed = False
     for result in results:
         if key_value in result.text:
@@ -29,7 +30,7 @@ def validate_retrieval(results: List[NodeWithScore], key_value: str):


 @pytest.mark.parametrize("session", ["astra_db"], indirect=["session"])  # "cassandra",
-def test_sync(session: Session):
+def test_sync(session: Session) -> None:
     table_name = "LlamaIndex_colbert_sync"

     batch_size = 5  # 640 recommended for production use
@@ -47,15 +48,15 @@ def test_sync(session: Session):
         embedding_model=embedding_model,
     )

-    docs: List[Document] = []
+    docs = []
     docs.append(
         Document(
-            text=TestData.marine_animals_text(), metadata={"name": "marine_animals"}
+            text=TestData.marine_animals_text(), extra_info={"name": "marine_animals"}
         )
     )
     docs.append(
         Document(
-            text=TestData.nebula_voyager_text(), metadata={"name": "nebula_voyager"}
+            text=TestData.nebula_voyager_text(), extra_info={"name": "nebula_voyager"}
         )
     )

@@ -64,20 +65,20 @@ def test_sync(session: Session):

     nodes = pipeline.run(documents=docs)

-    docs: Dict[str, Tuple[List[str], List[Metadata]]] = {}
+    docs2: Dict[str, Tuple[List[str], List[Metadata]]] = {}

     for node in nodes:
         doc_id = node.metadata["name"]
-        if doc_id not in docs:
-            docs[doc_id] = ([], [])
-        docs[doc_id][0].append(node.text)
-        docs[doc_id][1].append(node.metadata)
+        if doc_id not in docs2:
+            docs2[doc_id] = ([], [])
+        docs2[doc_id][0].append(node.text)
+        docs2[doc_id][1].append(node.metadata)

     logging.debug("Starting to embed ColBERT docs and save them to the database")

-    for doc_id in docs:
-        texts = docs[doc_id][0]
-        metadatas = docs[doc_id][1]
+    for doc_id in docs2:
+        texts = docs2[doc_id][0]
+        metadatas = docs2[doc_id][1]

         logging.debug("processing %s that has %s chunks", doc_id, len(texts))

@@ -87,22 +88,24 @@ def test_sync(session: Session):
         retriever=vector_store.as_retriever(), similarity_top_k=5
     )

-    Settings.llm = None
+    Settings.llm = MockLLM()

     response_synthesizer = get_response_synthesizer()

-    pipeline = RetrieverQueryEngine(
+    pipeline2 = RetrieverQueryEngine(
         retriever=retriever,
         response_synthesizer=response_synthesizer,
     )

-    results = pipeline.retrieve("Who developed the Astroflux Navigator?")
+    results = pipeline2.retrieve(QueryBundle("Who developed the Astroflux Navigator?"))
     assert validate_retrieval(results, key_value="Astroflux Navigator")

-    results = pipeline.retrieve(
-        "Describe the phenomena known as 'Chrono-spatial Echoes'"
+    results = pipeline2.retrieve(
+        QueryBundle("Describe the phenomena known as 'Chrono-spatial Echoes'")
     )
     assert validate_retrieval(results, key_value="Chrono-spatial Echoes")

-    results = pipeline.retrieve("How do anglerfish adapt to the deep ocean's darkness?")
+    results = pipeline2.retrieve(
+        QueryBundle("How do anglerfish adapt to the deep ocean's darkness?")
+    )
     assert validate_retrieval(results, key_value="anglerfish")
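Several of the edits above exist only to satisfy typed APIs: Settings.llm = MockLLM() replaces the None assignment, presumably because the llm setter's declared type does not admit None (LlamaIndex substitutes a mock LLM for a missing one anyway), and the query strings are wrapped in QueryBundle to match the annotated signature of RetrieverQueryEngine.retrieve. The renames from docs to docs2 and pipeline to pipeline2 come from mypy binding a single type to each name per scope. A minimal sketch of that rule (hypothetical values):

from typing import Dict, List

docs: List[str] = ["a", "b"]
# docs = {"a": 1}   # mypy: incompatible types in assignment
docs2: Dict[str, int] = {"a": 1}  # a fresh name type-checks cleanly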
11 changes: 6 additions & 5 deletions libs/llamaindex/tests/unit_tests/test_import.py
@@ -1,7 +1,8 @@
 import importlib
+from typing import Any, Callable


-def test_import():
+def test_import() -> None:
     import astrapy  # noqa: F401
     import cassio  # noqa: F401
     import openai  # noqa: F401
@@ -11,25 +12,25 @@ def test_import():
     from llama_index.vector_stores.cassandra import CassandraVectorStore  # noqa: F401


-def check_no_import(fn: callable):
+def check_no_import(fn: Callable[[], Any]) -> None:
     try:
         fn()
         raise RuntimeError("Should have failed to import")
     except ImportError:
         pass


-def test_not_import():
+def test_not_import() -> None:
     check_no_import(lambda: importlib.import_module("langchain.vectorstores"))
     check_no_import(lambda: importlib.import_module("langchain_astradb"))
     check_no_import(lambda: importlib.import_module("langchain_core"))
     check_no_import(lambda: importlib.import_module("langsmith"))


-def test_meta():
+def test_meta() -> None:
     from importlib import metadata

-    def check_meta(package: str):
+    def check_meta(package: str) -> None:
         meta = metadata.metadata(package)
         assert meta["version"]
         assert meta["license"] == "BUSL-1.1"
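The builtin callable is a predicate function, not a type, so strict mypy rejects it as an annotation; Callable[[], Any] instead declares a zero-argument callable with an unconstrained result. A short sketch (hypothetical):

from typing import Any, Callable


def run_once(fn: Callable[[], Any]) -> None:
    # Accepts any zero-argument callable; the return value is discarded.
    fn()


run_once(lambda: print("probe"))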
22 changes: 14 additions & 8 deletions libs/llamaindex/tox.ini
@@ -1,24 +1,30 @@
 [tox]
 min_version = 4.0
-envlist = py311
+envlist = type, unit-tests, integration-tests

+[testenv]
+description = install dependencies
+skip_install = true
+allowlist_externals = poetry
+commands_pre =
+    poetry env use system
+    poetry install -E colbert
+
 [testenv:unit-tests]
 description = run unit tests
-deps =
-    poetry
 commands =
-    poetry install
-    poetry build
     poetry run pytest --disable-warnings {toxinidir}/tests/unit_tests

 [testenv:integration-tests]
 description = run integration tests
-deps =
-    poetry
 pass_env =
     ASTRA_DB_TOKEN
     ASTRA_DB_ID
     ASTRA_DB_ENV
 commands =
-    poetry install -E colbert
     poetry run pytest --disable-warnings {toxinidir}/tests/integration_tests
+
+[testenv:type]
+description = run type checking
+commands =
+    poetry run mypy {toxinidir}
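With the expanded envlist, a bare tox invocation now runs the type, unit-tests, and integration-tests environments in sequence, and tox -e type runs mypy alone. The shared [testenv] section installs the colbert extra via commands_pre before each environment's own commands, which is why the per-environment poetry install and poetry build steps could be dropped.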
