feat(chat): fetch examples from chromadb for few-shot learning #21

Merged
merged 16 commits on Mar 1, 2024
Changes from 15 commits
1 change: 1 addition & 0 deletions aikg/config/chat.py
@@ -62,6 +62,7 @@ class ChatConfig(BaseModel):
When generating sparql:
* Never enclose the sparql in back-quotes

{examples_str}

Use the following format:

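For context, the {examples_str} slot is filled at request time with the few-shot block built by generate_examples (added in aikg/utils/chat.py below). A minimal sketch of the substitution, using an abridged stand-in for the real template:

# Abridged stand-in for the ChatConfig template, not the actual prompt text:
# shows how examples_str slots into the prompt before the LLM call.
template = (
    "When generating sparql:\n"
    "* Never enclose the sparql in back-quotes\n\n"
    "{examples_str}\n\n"
    "Use the following format:\n"
    "Question: {question_str}\n"
)
prompt = template.format(
    examples_str="Examples:\n\nQuestion:\nHow many datasets are there?\nQuery:\nSELECT ...",
    question_str="Which datasets mention proteins?",
)
print(prompt)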
5 changes: 4 additions & 1 deletion aikg/config/chroma.py
@@ -31,6 +31,8 @@ class ChromaConfig(BaseModel):
        The port of the ChromaDB server.
    collection_name:
        The name of the ChromaDB collection to store the index in.
    collection_examples:
        The name of the ChromaDB collection to store examples in.
    embedding_model_id:
        The HuggingFace ID of the embedding model to use.
    batch_size:
@@ -41,7 +43,8 @@

    host: str = os.environ.get("CHROMA_HOST", "127.0.0.1")
    port: int = int(os.environ.get("CHROMA_PORT", "8000"))
-    collection_name: str = os.environ.get("CHROMA_COLLECTION", "test")
+    collection_name: str = os.environ.get("CHROMA_COLLECTION", "schema")
+    collection_examples: str = os.environ.get("CHROMA_EXAMPLES", "examples")
    batch_size: int = int(os.environ.get("CHROMA_BATCH_SIZE", "50"))
    embedding_model: str = os.environ.get("CHROMA_MODEL", "all-mpnet-base-v2")
    persist_directory: str = os.environ.get("CHROMA_PERSIST_DIR", ".chroma/")
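Since the defaults are evaluated with os.environ.get in the class body, they are fixed when the module is first imported. A minimal sketch of how the new field is picked up (collection name invented); note the variable must be set before the import:

# Hypothetical usage: each field falls back to an environment variable, so the
# examples collection can be renamed without touching code.
import os

os.environ["CHROMA_EXAMPLES"] = "my-examples"

from aikg.config import ChromaConfig

cfg = ChromaConfig()
assert cfg.collection_examples == "my-examples"
assert cfg.collection_name == "schema"  # the new default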
4 changes: 2 additions & 2 deletions aikg/flows/chroma_build.py
@@ -30,7 +30,7 @@
from typing_extensions import Annotated
import uuid

-from chromadb.api import API, Collection
+from chromadb.api import ClientAPI, Collection
from dotenv import load_dotenv
from langchain.schema import Document
from more_itertools import chunked
@@ -53,7 +53,7 @@ def init_chromadb(
    collection_name: str,
    embedding_model: str,
    persist_directory: str,
-) -> Tuple[API, Collection]:
+) -> Tuple[ClientAPI, Collection]:
    """Prepare chromadb client."""
    client = akchroma.setup_client(host, port, persist_directory=persist_directory)
    coll = akchroma.setup_collection(client, collection_name, embedding_model)
150 changes: 150 additions & 0 deletions aikg/flows/chroma_examples.py
@@ -0,0 +1,150 @@
# kg-llm-interface
# Copyright 2023 - Swiss Data Science Center (SDSC)
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
# Eidgenössische Technische Hochschule Zürich (ETHZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This flow builds a ChromaDB vector index from examples consisting of pairs of questions and SPARQL queries.

For each subject in the target graph, a document is generated. The document consists of:
* A human readable question (document body)
* A corresponding SPARQL query (document metadata)

The documents are then stored in a vector database. The embedding is computed using the document body (questions),
and SPAQRL queries included as metadata. The index is persisted to disk and can be subsequently loaded into memory
for querying."""

from pathlib import Path
from typing import Optional, Tuple
from typing_extensions import Annotated
import uuid
import os

from chromadb.api import ClientAPI, Collection
from dotenv import load_dotenv
from langchain.schema import Document
from more_itertools import chunked
from prefect import flow, task
from prefect import get_run_logger
import typer

from aikg.config import ChromaConfig
from aikg.config.common import parse_yaml_config
import aikg.utils.io as akio
import aikg.utils.chroma as akchroma


@task
def init_chromadb(
    host: str,
    port: int,
    collection_name: str,
    embedding_model: str,
    persist_directory: str,
) -> Tuple[ClientAPI, Collection]:
    """Prepare chromadb client."""
    client = akchroma.setup_client(host, port, persist_directory=persist_directory)
    coll = akchroma.setup_collection(client, collection_name, embedding_model)

    return client, coll


@task
def index_batch(batch: list[Document]):
    """Send a batch of documents for indexing in the vector store.

    Relies on the module-level ``coll`` set up by the flow."""
    coll.add(
        ids=[str(uuid.uuid4()) for _ in batch],
        documents=[doc.page_content for doc in batch],
        metadatas=[doc.metadata for doc in batch],
    )


@task
def get_sparql_examples(dir: Path) -> list[Document]:
    """Parse each file in the directory into an example document."""
    docs = []
    for file_name in os.listdir(dir):
        # Provide each file as a text stream to be parsed
        with open(os.path.join(dir, file_name)) as stream:
            docs.extend(akio.parse_sparql_example(stream))
    return docs


@flow
def chroma_build_examples_flow(
    chroma_input_dir: Path,
    chroma_cfg: ChromaConfig = ChromaConfig(),
):
    """Build a ChromaDB vector index from examples.

    Parameters
    ----------
    chroma_input_dir:
        Directory containing files with example question-query pairs. Each file
        holds a SPARQL query whose first line is the question, written as a comment.
    chroma_cfg:
        ChromaDB configuration.
    """
    load_dotenv()
    logger = get_run_logger()
    logger.info("Started")
    # Connect to external resources
    global coll
    client, coll = init_chromadb(
        chroma_cfg.host,
        chroma_cfg.port,
        chroma_cfg.collection_examples,
        chroma_cfg.embedding_model,
        chroma_cfg.persist_directory,
    )

    # Create example documents
    docs = get_sparql_examples(
        dir=chroma_input_dir,
    )

    # Vectorize and index documents by batches to reduce overhead
    logger.info(f"Indexing by batches of {chroma_cfg.batch_size} items")
    embed_counter = 0
    for batch in chunked(docs, chroma_cfg.batch_size):
        embed_counter += len(batch)
        index_batch(batch)
    logger.info(f"Indexed {embed_counter} items.")


def cli(
    chroma_input_dir: Annotated[
        Path,
        typer.Argument(
            help="Path to directory with example SPARQL queries",
            exists=True,
            file_okay=False,
            dir_okay=True,
        ),
    ],
    chroma_cfg_path: Annotated[
        Optional[Path],
        typer.Option(help="YAML file with Chroma client configuration."),
    ] = None,
):
    """Command line wrapper for SPARQL examples to ChromaDB index flow."""
    chroma_cfg = (
        parse_yaml_config(chroma_cfg_path, ChromaConfig)
        if chroma_cfg_path
        else ChromaConfig()
    )
    chroma_build_examples_flow(chroma_input_dir, chroma_cfg)


if __name__ == "__main__":
    typer.run(cli)
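To illustrate the expected input, a hypothetical script that prepares one example file (directory, file name, and query are all invented):

# Hypothetical input: one file per example, with the question as a '#' comment
# on the first line and the SPARQL query on the remaining lines.
from pathlib import Path

Path("examples").mkdir(exist_ok=True)
Path("examples/count_datasets.sparql").write_text(
    "# How many datasets are in the catalog?\n"
    "SELECT (COUNT(?d) AS ?n) WHERE { ?d a <http://schema.org/Dataset> }\n"
)

The flow can then be launched through the Typer wrapper (module path assumed):

python -m aikg.flows.chroma_examples examples/ --chroma-cfg-path chroma.yaml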
16 changes: 15 additions & 1 deletion aikg/server.py
@@ -32,7 +32,7 @@
from aikg.config import ChatConfig, ChromaConfig, SparqlConfig
from aikg.config.common import parse_yaml_config
from aikg.models import Conversation, Message
-from aikg.utils.chat import generate_answer, generate_sparql
+from aikg.utils.chat import generate_answer, generate_examples, generate_sparql
from aikg.utils.llm import setup_llm_chain, setup_llm
from aikg.utils.chroma import setup_collection, setup_client
from aikg.utils.rdf import setup_kg, query_kg
@@ -58,6 +58,11 @@
    chroma_config.collection_name,
    chroma_config.embedding_model,
)
collection_examples = setup_collection(
    client,
    chroma_config.collection_examples,
    chroma_config.embedding_model,
)
llm = setup_llm(chat_config.model_id, chat_config.max_new_tokens)
# For now, both chains share the same model to spare memory
answer_chain = setup_llm_chain(llm, chat_config.answer_template)
@@ -91,6 +96,15 @@ async def ask(question: str) -> Message:
    return Message(text=answer, sender="AI", time=datetime.now())


@app.get("/examples/")
async def examples(question: str) -> Message:
    """Retrieve few-shot examples matching the question
    and return them as a prompt-ready block."""
    examples_str = generate_examples(question, collection_examples)
    return Message(text=examples_str, sender="AI", time=datetime.now())


@app.get("/sparql/")
async def sparql(question: str) -> Message:
"""Generate and return sparql query from question."""
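With the server running locally (host, port, and question are illustrative), the new route can be smoke-tested from Python:

# Hypothetical client call: fetch the few-shot block built for a question.
import requests

resp = requests.get(
    "http://127.0.0.1:8000/examples/",
    params={"question": "How many datasets are there?"},
)
print(resp.json()["text"])  # the formatted examples block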
35 changes: 33 additions & 2 deletions aikg/utils/chat.py
@@ -56,7 +56,11 @@ def post_process_answer(answer: str) -> str:


def generate_sparql(
-    question: str, collection: Collection, llm_chain: LLMChain, limit: int = 5
+    question: str,
+    collection: Collection,
+    llm_chain: LLMChain,
+    examples: str = "",
+    limit: int = 5,
) -> str:
    """Retrieve the k-nearest documents from the vector store and synthesize a
    SPARQL query."""
@@ -67,10 +71,37 @@
    triples = "\n".join([res.get("triples", "") for res in results["metadatas"][0]])
    # Convert to turtle for better readability and fewer tokens
    triples = Graph().parse(data=triples).serialize(format="turtle")
-    query = llm_chain.run(question_str=question, context_str=triples)
+    query = llm_chain.run(
+        question_str=question, context_str=triples, examples_str=examples
+    )
    return query


def generate_examples(
    question: str,
    collection: Collection,
    limit: int = 5,
) -> str:
    """Retrieve the k-nearest questions from the examples in the vector store
    and return them together with their corresponding queries."""

    # Retrieve the top-k example questions most similar to the input question
    examples = collection.query(query_texts=question, n_results=limit)
    # Extract the relevant fields from the result dict
    example_docs = examples["documents"][0]
    example_meta = examples["metadatas"][0]
    # Assemble the few-shot block for the prompt
    example_prompt = "Examples: \n\n"
    for doc, meta in zip(example_docs, example_meta):
        example_prompt += f"""
Question:
{doc}
Query:
{meta['query']}
"""
    return example_prompt


def generate_answer(
question: str,
query: str,
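The two helpers are meant to compose: the string returned by generate_examples feeds the new examples parameter of generate_sparql. A sketch of the wiring, with the collections and chain assumed to be set up as in aikg/server.py:

# Sketch only: schema_collection, examples_collection and sparql_chain are
# assumed to exist, set up as in aikg/server.py.
examples_str = generate_examples(question, examples_collection, limit=3)
query = generate_sparql(
    question, schema_collection, sparql_chain, examples=examples_str
)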
6 changes: 3 additions & 3 deletions aikg/utils/chroma.py
@@ -17,10 +17,10 @@

import chromadb
from chromadb.config import Settings
-from chromadb.api import API, Collection
+from chromadb.api import ClientAPI, Collection


-def setup_client(host: str, port: int, persist_directory: str = ".chroma") -> API:
+def setup_client(host: str, port: int, persist_directory: str = ".chroma") -> ClientAPI:
    """Prepare chromadb client. If host is 'local', chromadb will run in client-only mode."""
    if host == "local":
        chroma_client = chromadb.PersistentClient(path=persist_directory)
@@ -30,7 +30,7 @@ def setup_client(host: str, port: int, persist_directory: str = ".chroma") -> API:


def setup_collection(
-    client: API,
+    client: ClientAPI,
    collection_name: str,
    embedding_model: str,
) -> Collection:
27 changes: 26 additions & 1 deletion aikg/utils/io.py
@@ -16,8 +16,10 @@
# limitations under the License.

import requests
import os
from pathlib import Path

from typing import List, TextIO
from langchain.schema import Document
from tqdm import tqdm


@@ -31,3 +33,26 @@ def download_file(url: str, output_path: str | Path):
        for chunk in tqdm(response.iter_content(chunk_size=8192)):
            if chunk:
                f.write(chunk)


def parse_sparql_example(example: TextIO) -> List[Document]:
    """
    Parse a text stream whose first line is a question (starting with '#') and
    whose remaining lines are a SPARQL query. The content is reformatted into a
    document whose page content is the question, with the query attached as
    metadata.
    """
    lines = example.read().splitlines()
    # Extract the question, dropping the leading '#' from the first line
    question = lines[0].strip().lstrip("#").strip()
    # The remaining lines form the SPARQL query
    sparql_query = "\n".join(lines[1:])
    return [Document(page_content=question, metadata={"query": sparql_query})]
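A quick sanity check of the parser on an in-memory stream (question and query invented):

from io import StringIO

from aikg.utils.io import parse_sparql_example

doc = parse_sparql_example(
    StringIO(
        "# How many datasets are there?\n"
        "SELECT (COUNT(?d) AS ?n) WHERE { ?d a ?type }"
    )
)[0]
assert doc.page_content == "How many datasets are there?"
assert doc.metadata["query"].startswith("SELECT")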
4 changes: 0 additions & 4 deletions aikg/utils/rdf.py
@@ -147,10 +147,6 @@ def get_subjects_docs(
    )
    docs = []

-    # Skipping header row
-    if isinstance(kg, SPARQLWrapper):
-        results = results[1:]
-
    for sub, label, comment in results:
        if comment is None:
            comment = ""