# main.py
import os
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from dotenv import load_dotenv
from langchain_pinecone import PineconeVectorStore
from langgraph.graph import MessagesState, StateGraph, END
from langgraph.prebuilt import tools_condition
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_chroma import Chroma
from langgraph.checkpoint.memory import MemorySaver

load_dotenv()

graph_builder = StateGraph(MessagesState)

llm = ChatOllama(model="llama3.1:8b")
embeddings = OllamaEmbeddings(model="chroma/all-minilm-l6-v2-f32")
vector_store = Chroma(embedding_function=embeddings, persist_directory=os.environ['CHROMA_PATH'])
# vector_store = PineconeVectorStore(
#     index_name=os.environ['INDEX_NAME'], embedding=embeddings
# )
# new_vector_store = FAISS.load_local(
#     "faiss_index",
#     embeddings=embeddings,
#     allow_dangerous_deserialization=True,
# )
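# Sketch (an assumption, not part of the original file): the faiss, FAISS, and
# InMemoryDocstore imports above suggest a local FAISS index was built and
# saved at some point; one plausible way to do that with the same embeddings:
# index = faiss.IndexFlatL2(len(embeddings.embed_query("hello world")))
# faiss_store = FAISS(
#     embedding_function=embeddings,
#     index=index,
#     docstore=InMemoryDocstore(),
#     index_to_docstore_id={},
# )
# faiss_store.save_local("faiss_index")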


@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    print("query=", query)
    retrieved_docs = vector_store.similarity_search(
        query,
        k=2,
        filter={"title": "Departments"},
    )
    print("retrieved docs=", retrieved_docs)
    serialized = "\n\n".join(
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in retrieved_docs
    )
    print("serialized=", serialized)
    return serialized, retrieved_docs


# Step 1: Generate an AIMessage that may include a tool call to be sent.
def query_or_respond(state: MessagesState):
    """Generate a tool call for retrieval, or respond directly."""
    llm_with_tools = llm.bind_tools([retrieve])
    response = llm_with_tools.invoke(state["messages"])
    print("response=", response)
    # MessagesState appends messages to state instead of overwriting them.
    return {"messages": [response]}


# Step 2: Execute the retrieval.
tools = ToolNode([retrieve])


# Step 3: Generate a response using the retrieved content.
def generate(state: MessagesState):
    """Generate the answer from the retrieved context."""
    # Collect the most recent ToolMessages (the retrieval results).
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]

    # Format the retrieved documents into the system prompt.
    docs_content = "\n\n".join(doc.content for doc in tool_messages)
    system_message_content = (
        "You are the receptionist for Mar Baselios College of Engineering and Technology. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        f"{docs_content}"
    )
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    prompt = [SystemMessage(system_message_content)] + conversation_messages

    # Run the model on the assembled prompt.
    response = llm.invoke(prompt)
    return {"messages": [response]}


# Wire up the graph: entry point, conditional routing to the tool node, then generation.
graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)

graph_builder.set_entry_point("query_or_respond")
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

# Persist conversation state between turns with an in-memory checkpointer.
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)
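

# Illustrative usage sketch (an assumption, not part of the original file):
# because the graph is compiled with a MemorySaver checkpointer, each call
# needs a configurable thread_id so conversation history is kept per thread.
# The question and thread id below are placeholders.
if __name__ == "__main__":
    config = {"configurable": {"thread_id": "demo-thread"}}
    result = graph.invoke(
        {"messages": [{"role": "user", "content": "Which departments does the college have?"}]},
        config,
    )
    print(result["messages"][-1].content)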