# query.py
import os

import chromadb
from dotenv import load_dotenv
from llama_index.core import (
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
)
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.llms.groq import Groq
from llama_index.vector_stores.chroma import ChromaVectorStore

# Optional: pull a community RAG prompt from the LangChain hub instead of the
# inline template defined in query() below.
# from langchain import hub
# langchain_prompt = hub.pull('rlm/rag-prompt')

# Load API tokens from the environment (.env)
load_dotenv()
GROQ = os.getenv("GROQ")
HF_TOKEN = os.getenv("HF_TOKEN")
cohere_api_key = os.getenv("COHERE_API_KEY")

llm_model = "llama-3.1-8b-instant"

def create_folders_and_file(folder_path, filename) -> str:
    """
    Creates folders and subfolders if they don't exist and creates an empty
    file in the deepest folder.

    Args:
        folder_path (str): Path to the top-level folder.
        filename (str): Name of the file to create in the deepest folder.

    Returns:
        str: Full path of the created file, or None if creation failed.
    """
    # Ensure path is a string
    if not isinstance(folder_path, str):
        raise TypeError("folder_path must be a string")

    # Create folders, tolerating directories that already exist
    try:
        os.makedirs(folder_path, exist_ok=True)
    except OSError as e:
        print(f"Error creating directories: {e}")
        return None

    # Create (or truncate) the file at the full path
    full_path = os.path.join(folder_path, filename)
    try:
        with open(full_path, "w"):
            pass
        print(f"Successfully created file: {full_path}")
        return full_path
    except OSError as e:
        print(f"Error creating file: {e}")
        return None

def generate_embeddings(
    documents_path: str, server: str, embedding_path: str, channel: str
) -> None:
    print("Generating embeddings...")
    load_dotenv()

    # Initialize the Cohere embedding model; documents being indexed use the
    # "search_document" input type (queries use "search_query").
    embeddings = CohereEmbedding(
        api_key=cohere_api_key,
        model_name="embed-english-light-v3.0",
        input_type="search_document",
    )
    Settings.embed_model = embeddings
    # Settings.embed_model = HuggingFaceEmbedding(
    #     model_name='nomic-ai/nomic-embed-text-v1'
    # )

    # Load the chat-log files and tag each document with its server and channel
    documents = SimpleDirectoryReader(documents_path).load_data()
    for document in documents:
        document.metadata = {"server": server, "channel": channel}

    # Persist the vectors in a Chroma collection named after the server
    db = chromadb.PersistentClient(path=embedding_path)
    chroma_collection = db.get_or_create_collection(server)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Build the index; embeddings are written to the Chroma store as a side effect
    VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    print("Done generating embeddings")

def query(prompt: str, server: str, embedding_path: str, channel: str) -> str:
    llm = Groq(model=llm_model, api_key=GROQ)
    Settings.llm = llm

    # Initialize the Cohere embedding model with the "search_query" input type
    # so the prompt is embedded for retrieval against the indexed documents
    embeddings = CohereEmbedding(
        api_key=cohere_api_key,
        model_name="embed-english-light-v3.0",
        input_type="search_query",
    )
    Settings.embed_model = embeddings

    # Initialize the Chroma client and fetch the server's collection
    db = chromadb.PersistentClient(path=embedding_path)
    chroma_collection = db.get_or_create_collection(server)

    # Assign Chroma as the vector_store in the storage context
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    qa_prompt_tmpl = (
        "The following is a Discord chat log.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context of the Discord chat log and not prior knowledge, "
        "answer the query.\n"
        "Query: {query_str}\n"
        "Answer: "
    )
    qa_prompt = PromptTemplate(qa_prompt_tmpl)

    # Load the index from the stored vectors
    index = VectorStoreIndex.from_vector_store(
        vector_store, storage_context=storage_context
    )
    query_engine = index.as_query_engine(text_qa_template=qa_prompt)
    response = query_engine.query(f"Query made from {channel}: {prompt}")
    return str(response)
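

# A minimal usage sketch (not part of the original module): index a folder of
# exported chat logs, then ask a question against them. The paths, server name,
# channel name, and question below are placeholder assumptions; adjust them to
# your setup and make sure GROQ and COHERE_API_KEY are set in your .env file.
if __name__ == "__main__":
    logs_dir = "./data/my-server"   # hypothetical folder of exported chat-log files
    chroma_dir = "./chroma_db"      # hypothetical on-disk Chroma location
    server_name = "my-server"       # hypothetical Discord server name
    channel_name = "general"        # hypothetical channel name

    # Build (or rebuild) the vector store for the server's chat logs
    generate_embeddings(logs_dir, server_name, chroma_dir, channel_name)

    # Ask a question against the indexed logs
    answer = query(
        "What did people discuss about the release date?",
        server_name,
        chroma_dir,
        channel_name,
    )
    print(answer)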