From e9fe4f22f21cc4d09970b9235599098028af99da Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 25 Jul 2024 02:15:55 +0200 Subject: [PATCH] Add type checking of ragstack-llama-index --- libs/llamaindex/pyproject.toml | 20 +++++++++ .../colbert/colbert_retriever.py | 6 +-- .../tests/integration_tests/conftest.py | 3 +- .../tests/integration_tests/test_colbert.py | 43 ++++++++++--------- .../tests/unit_tests/test_import.py | 11 ++--- libs/llamaindex/tox.ini | 22 ++++++---- 6 files changed, 68 insertions(+), 37 deletions(-) diff --git a/libs/llamaindex/pyproject.toml b/libs/llamaindex/pyproject.toml index a86bcca93..9c68b5c40 100644 --- a/libs/llamaindex/pyproject.toml +++ b/libs/llamaindex/pyproject.toml @@ -44,6 +44,26 @@ google = ["llama-index-llms-gemini", "llama-index-multi-modal-llms-gemini", "lla azure = ["llama-index-llms-azure-openai", "llama-index-embeddings-azure-openai"] bedrock = ["llama-index-llms-bedrock", "llama-index-embeddings-bedrock"] +[tool.poetry.group.dev.dependencies] +mypy = "^1.11.0" + [tool.poetry.group.test.dependencies] ragstack-ai-tests-utils = { path = "../tests-utils", develop = true } ragstack-ai-colbert = { path = "../colbert", develop = true } + +[tool.mypy] +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +follow_imports = "normal" +ignore_missing_imports = true +no_implicit_reexport = true +show_error_codes = true +show_error_context = true +strict_equality = true +strict_optional = true +warn_redundant_casts = true +warn_return_any = true +warn_unused_ignores = true diff --git a/libs/llamaindex/ragstack_llamaindex/colbert/colbert_retriever.py b/libs/llamaindex/ragstack_llamaindex/colbert/colbert_retriever.py index 274fac190..5b8e2ef7c 100644 --- a/libs/llamaindex/ragstack_llamaindex/colbert/colbert_retriever.py +++ b/libs/llamaindex/ragstack_llamaindex/colbert/colbert_retriever.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from llama_index.core.callbacks.base import CallbackManager from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K @@ -27,7 +27,7 @@ def __init__( retriever: ColbertBaseRetriever, similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, callback_manager: Optional[CallbackManager] = None, - object_map: Optional[dict] = None, + object_map: Optional[Dict[str, Any]] = None, verbose: bool = False, query_maxlen: int = -1, ) -> None: @@ -51,6 +51,6 @@ def _retrieve( query_maxlen=self._query_maxlen, ) return [ - NodeWithScore(node=TextNode(text=c.text, metadata=c.metadata), score=s) + NodeWithScore(node=TextNode(text=c.text, extra_info=c.metadata), score=s) for (c, s) in chunk_scores ] diff --git a/libs/llamaindex/tests/integration_tests/conftest.py b/libs/llamaindex/tests/integration_tests/conftest.py index f0c8c68e0..f2d61325d 100644 --- a/libs/llamaindex/tests/integration_tests/conftest.py +++ b/libs/llamaindex/tests/integration_tests/conftest.py @@ -1,4 +1,5 @@ import pytest +from _pytest.fixtures import FixtureRequest from cassandra.cluster import Session from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore @@ -17,7 +18,7 @@ def astra_db() -> AstraDBTestStore: @pytest.fixture() -def session(request) -> Session: +def session(request: FixtureRequest) -> Session: test_store = request.getfixturevalue(request.param) session = test_store.create_cassandra_session() session.default_timeout = 180 diff --git a/libs/llamaindex/tests/integration_tests/test_colbert.py b/libs/llamaindex/tests/integration_tests/test_colbert.py index 7aa473d04..c23df3e1c 100644 --- a/libs/llamaindex/tests/integration_tests/test_colbert.py +++ b/libs/llamaindex/tests/integration_tests/test_colbert.py @@ -5,8 +5,9 @@ from cassandra.cluster import Session from llama_index.core import Settings, get_response_synthesizer from llama_index.core.ingestion import IngestionPipeline +from llama_index.core.llms import MockLLM from llama_index.core.query_engine import RetrieverQueryEngine -from llama_index.core.schema import Document, NodeWithScore +from llama_index.core.schema import Document, NodeWithScore, QueryBundle from llama_index.core.text_splitter import SentenceSplitter from ragstack_colbert import ( CassandraDatabase, @@ -20,7 +21,7 @@ logging.getLogger("cassandra").setLevel(logging.ERROR) -def validate_retrieval(results: List[NodeWithScore], key_value: str): +def validate_retrieval(results: List[NodeWithScore], key_value: str) -> bool: passed = False for result in results: if key_value in result.text: @@ -29,7 +30,7 @@ def validate_retrieval(results: List[NodeWithScore], key_value: str): @pytest.mark.parametrize("session", ["astra_db"], indirect=["session"]) # "cassandra", -def test_sync(session: Session): +def test_sync(session: Session) -> None: table_name = "LlamaIndex_colbert_sync" batch_size = 5 # 640 recommended for production use @@ -47,15 +48,15 @@ def test_sync(session: Session): embedding_model=embedding_model, ) - docs: List[Document] = [] + docs = [] docs.append( Document( - text=TestData.marine_animals_text(), metadata={"name": "marine_animals"} + text=TestData.marine_animals_text(), extra_info={"name": "marine_animals"} ) ) docs.append( Document( - text=TestData.nebula_voyager_text(), metadata={"name": "nebula_voyager"} + text=TestData.nebula_voyager_text(), extra_info={"name": "nebula_voyager"} ) ) @@ -64,20 +65,20 @@ def test_sync(session: Session): nodes = pipeline.run(documents=docs) - docs: Dict[str, Tuple[List[str], List[Metadata]]] = {} + docs2: Dict[str, Tuple[List[str], List[Metadata]]] = {} for node in nodes: doc_id = node.metadata["name"] - if doc_id not in docs: - docs[doc_id] = ([], []) - docs[doc_id][0].append(node.text) - docs[doc_id][1].append(node.metadata) + if doc_id not in docs2: + docs2[doc_id] = ([], []) + docs2[doc_id][0].append(node.text) + docs2[doc_id][1].append(node.metadata) logging.debug("Starting to embed ColBERT docs and save them to the database") - for doc_id in docs: - texts = docs[doc_id][0] - metadatas = docs[doc_id][1] + for doc_id in docs2: + texts = docs2[doc_id][0] + metadatas = docs2[doc_id][1] logging.debug("processing %s that has %s chunks", doc_id, len(texts)) @@ -87,22 +88,24 @@ def test_sync(session: Session): retriever=vector_store.as_retriever(), similarity_top_k=5 ) - Settings.llm = None + Settings.llm = MockLLM() response_synthesizer = get_response_synthesizer() - pipeline = RetrieverQueryEngine( + pipeline2 = RetrieverQueryEngine( retriever=retriever, response_synthesizer=response_synthesizer, ) - results = pipeline.retrieve("Who developed the Astroflux Navigator?") + results = pipeline2.retrieve(QueryBundle("Who developed the Astroflux Navigator?")) assert validate_retrieval(results, key_value="Astroflux Navigator") - results = pipeline.retrieve( - "Describe the phenomena known as 'Chrono-spatial Echoes'" + results = pipeline2.retrieve( + QueryBundle("Describe the phenomena known as 'Chrono-spatial Echoes'") ) assert validate_retrieval(results, key_value="Chrono-spatial Echoes") - results = pipeline.retrieve("How do anglerfish adapt to the deep ocean's darkness?") + results = pipeline2.retrieve( + QueryBundle("How do anglerfish adapt to the deep ocean's darkness?") + ) assert validate_retrieval(results, key_value="anglerfish") diff --git a/libs/llamaindex/tests/unit_tests/test_import.py b/libs/llamaindex/tests/unit_tests/test_import.py index c8e4a8337..c9200bdb4 100644 --- a/libs/llamaindex/tests/unit_tests/test_import.py +++ b/libs/llamaindex/tests/unit_tests/test_import.py @@ -1,7 +1,8 @@ import importlib +from typing import Any, Callable -def test_import(): +def test_import() -> None: import astrapy # noqa: F401 import cassio # noqa: F401 import openai # noqa: F401 @@ -11,7 +12,7 @@ def test_import(): from llama_index.vector_stores.cassandra import CassandraVectorStore # noqa: F401 -def check_no_import(fn: callable): +def check_no_import(fn: Callable[[], Any]) -> None: try: fn() raise RuntimeError("Should have failed to import") @@ -19,17 +20,17 @@ def check_no_import(fn: callable): pass -def test_not_import(): +def test_not_import() -> None: check_no_import(lambda: importlib.import_module("langchain.vectorstores")) check_no_import(lambda: importlib.import_module("langchain_astradb")) check_no_import(lambda: importlib.import_module("langchain_core")) check_no_import(lambda: importlib.import_module("langsmith")) -def test_meta(): +def test_meta() -> None: from importlib import metadata - def check_meta(package: str): + def check_meta(package: str) -> None: meta = metadata.metadata(package) assert meta["version"] assert meta["license"] == "BUSL-1.1" diff --git a/libs/llamaindex/tox.ini b/libs/llamaindex/tox.ini index e3fdf2da9..4e67d42e7 100644 --- a/libs/llamaindex/tox.ini +++ b/libs/llamaindex/tox.ini @@ -1,24 +1,30 @@ [tox] min_version = 4.0 -envlist = py311 +envlist = type, unit-tests, integration-tests + +[testenv] +description = install dependencies +skip_install = true +allowlist_externals = poetry +commands_pre = + poetry env use system + poetry install -E colbert [testenv:unit-tests] description = run unit tests -deps = - poetry commands = - poetry install - poetry build poetry run pytest --disable-warnings {toxinidir}/tests/unit_tests [testenv:integration-tests] description = run integration tests -deps = - poetry pass_env = ASTRA_DB_TOKEN ASTRA_DB_ID ASTRA_DB_ENV commands = - poetry install -E colbert poetry run pytest --disable-warnings {toxinidir}/tests/integration_tests + +[testenv:type] +description = run type checking +commands = + poetry run mypy {toxinidir}