Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type checking of ragstack-llama-index #614

Merged
merged 1 commit
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions libs/llamaindex/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,26 @@ google = ["llama-index-llms-gemini", "llama-index-multi-modal-llms-gemini", "lla
azure = ["llama-index-llms-azure-openai", "llama-index-embeddings-azure-openai"]
bedrock = ["llama-index-llms-bedrock", "llama-index-embeddings-bedrock"]

[tool.poetry.group.dev.dependencies]
mypy = "^1.11.0"

[tool.poetry.group.test.dependencies]
ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
ragstack-ai-colbert = { path = "../colbert", develop = true }

[tool.mypy]
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
follow_imports = "normal"
ignore_missing_imports = true
no_implicit_reexport = true
show_error_codes = true
show_error_context = true
strict_equality = true
strict_optional = true
warn_redundant_casts = true
warn_return_any = true
warn_unused_ignores = true
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
Expand Down Expand Up @@ -27,7 +27,7 @@ def __init__(
retriever: ColbertBaseRetriever,
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
callback_manager: Optional[CallbackManager] = None,
object_map: Optional[dict] = None,
object_map: Optional[Dict[str, Any]] = None,
verbose: bool = False,
query_maxlen: int = -1,
) -> None:
Expand All @@ -51,6 +51,6 @@ def _retrieve(
query_maxlen=self._query_maxlen,
)
return [
NodeWithScore(node=TextNode(text=c.text, metadata=c.metadata), score=s)
NodeWithScore(node=TextNode(text=c.text, extra_info=c.metadata), score=s)
for (c, s) in chunk_scores
]
3 changes: 2 additions & 1 deletion libs/llamaindex/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from _pytest.fixtures import FixtureRequest
from cassandra.cluster import Session
from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore

Expand All @@ -17,7 +18,7 @@ def astra_db() -> AstraDBTestStore:


@pytest.fixture()
def session(request) -> Session:
def session(request: FixtureRequest) -> Session:
test_store = request.getfixturevalue(request.param)
session = test_store.create_cassandra_session()
session.default_timeout = 180
Expand Down
43 changes: 23 additions & 20 deletions libs/llamaindex/tests/integration_tests/test_colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from cassandra.cluster import Session
from llama_index.core import Settings, get_response_synthesizer
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.llms import MockLLM
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.schema import Document, NodeWithScore
from llama_index.core.schema import Document, NodeWithScore, QueryBundle
from llama_index.core.text_splitter import SentenceSplitter
from ragstack_colbert import (
CassandraDatabase,
Expand All @@ -20,7 +21,7 @@
logging.getLogger("cassandra").setLevel(logging.ERROR)


def validate_retrieval(results: List[NodeWithScore], key_value: str):
def validate_retrieval(results: List[NodeWithScore], key_value: str) -> bool:
passed = False
for result in results:
if key_value in result.text:
Expand All @@ -29,7 +30,7 @@ def validate_retrieval(results: List[NodeWithScore], key_value: str):


@pytest.mark.parametrize("session", ["astra_db"], indirect=["session"]) # "cassandra",
def test_sync(session: Session):
def test_sync(session: Session) -> None:
table_name = "LlamaIndex_colbert_sync"

batch_size = 5 # 640 recommended for production use
Expand All @@ -47,15 +48,15 @@ def test_sync(session: Session):
embedding_model=embedding_model,
)

docs: List[Document] = []
docs = []
docs.append(
Document(
text=TestData.marine_animals_text(), metadata={"name": "marine_animals"}
text=TestData.marine_animals_text(), extra_info={"name": "marine_animals"}
)
)
docs.append(
Document(
text=TestData.nebula_voyager_text(), metadata={"name": "nebula_voyager"}
text=TestData.nebula_voyager_text(), extra_info={"name": "nebula_voyager"}
)
)

Expand All @@ -64,20 +65,20 @@ def test_sync(session: Session):

nodes = pipeline.run(documents=docs)

docs: Dict[str, Tuple[List[str], List[Metadata]]] = {}
docs2: Dict[str, Tuple[List[str], List[Metadata]]] = {}

for node in nodes:
doc_id = node.metadata["name"]
if doc_id not in docs:
docs[doc_id] = ([], [])
docs[doc_id][0].append(node.text)
docs[doc_id][1].append(node.metadata)
if doc_id not in docs2:
docs2[doc_id] = ([], [])
docs2[doc_id][0].append(node.text)
docs2[doc_id][1].append(node.metadata)

logging.debug("Starting to embed ColBERT docs and save them to the database")

for doc_id in docs:
texts = docs[doc_id][0]
metadatas = docs[doc_id][1]
for doc_id in docs2:
texts = docs2[doc_id][0]
metadatas = docs2[doc_id][1]

logging.debug("processing %s that has %s chunks", doc_id, len(texts))

Expand All @@ -87,22 +88,24 @@ def test_sync(session: Session):
retriever=vector_store.as_retriever(), similarity_top_k=5
)

Settings.llm = None
Settings.llm = MockLLM()

response_synthesizer = get_response_synthesizer()

pipeline = RetrieverQueryEngine(
pipeline2 = RetrieverQueryEngine(
retriever=retriever,
response_synthesizer=response_synthesizer,
)

results = pipeline.retrieve("Who developed the Astroflux Navigator?")
results = pipeline2.retrieve(QueryBundle("Who developed the Astroflux Navigator?"))
assert validate_retrieval(results, key_value="Astroflux Navigator")

results = pipeline.retrieve(
"Describe the phenomena known as 'Chrono-spatial Echoes'"
results = pipeline2.retrieve(
QueryBundle("Describe the phenomena known as 'Chrono-spatial Echoes'")
)
assert validate_retrieval(results, key_value="Chrono-spatial Echoes")

results = pipeline.retrieve("How do anglerfish adapt to the deep ocean's darkness?")
results = pipeline2.retrieve(
QueryBundle("How do anglerfish adapt to the deep ocean's darkness?")
)
assert validate_retrieval(results, key_value="anglerfish")
11 changes: 6 additions & 5 deletions libs/llamaindex/tests/unit_tests/test_import.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import importlib
from typing import Any, Callable


def test_import():
def test_import() -> None:
import astrapy # noqa: F401
import cassio # noqa: F401
import openai # noqa: F401
Expand All @@ -11,25 +12,25 @@ def test_import():
from llama_index.vector_stores.cassandra import CassandraVectorStore # noqa: F401


def check_no_import(fn: callable):
def check_no_import(fn: Callable[[], Any]) -> None:
try:
fn()
raise RuntimeError("Should have failed to import")
except ImportError:
pass


def test_not_import():
def test_not_import() -> None:
check_no_import(lambda: importlib.import_module("langchain.vectorstores"))
check_no_import(lambda: importlib.import_module("langchain_astradb"))
check_no_import(lambda: importlib.import_module("langchain_core"))
check_no_import(lambda: importlib.import_module("langsmith"))


def test_meta():
def test_meta() -> None:
from importlib import metadata

def check_meta(package: str):
def check_meta(package: str) -> None:
meta = metadata.metadata(package)
assert meta["version"]
assert meta["license"] == "BUSL-1.1"
Expand Down
22 changes: 14 additions & 8 deletions libs/llamaindex/tox.ini
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
[tox]
min_version = 4.0
envlist = py311
envlist = type, unit-tests, integration-tests

[testenv]
description = install dependencies
skip_install = true
allowlist_externals = poetry
commands_pre =
poetry env use system
poetry install -E colbert

[testenv:unit-tests]
description = run unit tests
deps =
poetry
commands =
poetry install
poetry build
poetry run pytest --disable-warnings {toxinidir}/tests/unit_tests

[testenv:integration-tests]
description = run integration tests
deps =
poetry
pass_env =
ASTRA_DB_TOKEN
ASTRA_DB_ID
ASTRA_DB_ENV
commands =
poetry install -E colbert
poetry run pytest --disable-warnings {toxinidir}/tests/integration_tests

[testenv:type]
description = run type checking
commands =
poetry run mypy {toxinidir}
Loading