# Necessary imports
import os
import requests
import numpy as np
import pandas as pd
import uvicorn
from dotenv import load_dotenv
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
# FastAPI app initialization
app = FastAPI()
# Request model for API endpoints
class QueryRequest(BaseModel):
    ticket_id: str
    # If we were to take inputs explicitly:
    # issue: str
    # description: str
    # category: str
# Take environment variables from the .env file
load_dotenv()
HF_TOKEN = os.environ["HF_TOKEN"]
client = InferenceClient(api_key=HF_TOKEN)
DATA_PATH = os.environ["DATA_PATH"]
NEW_TICKETS_PATH = os.path.join(DATA_PATH, "new_tickets.csv")
OLD_TICKETS_PATH = os.path.join(DATA_PATH, "old_tickets")
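# A minimal .env sketch (hypothetical values; the keys match the os.environ
# lookups above):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
#   DATA_PATH=/path/to/data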
# Read the data from the files into DataFrames
try:
    df_new_tickets = pd.read_csv(NEW_TICKETS_PATH)
    df_old_tickets_1 = pd.read_csv(os.path.join(OLD_TICKETS_PATH, "ticket_dump_1.csv"))
    df_old_tickets_2 = pd.read_excel(os.path.join(OLD_TICKETS_PATH, "ticket_dump_2.xlsx"))
    df_old_tickets_3 = pd.read_json(os.path.join(OLD_TICKETS_PATH, "ticket_dump_3.json"))
    df_old_tickets = pd.concat([df_old_tickets_1, df_old_tickets_2, df_old_tickets_3], ignore_index=True)
except FileNotFoundError:
    # HTTPException only makes sense inside a request handler; at import time
    # a plain RuntimeError is the right way to abort startup
    raise RuntimeError("Data files not found. Please check the DATA_PATH environment variable and ensure the files exist.")
except Exception as e:
    raise RuntimeError(f"Error reading data files: {e}")
# Check if necessary columns exist
required_new_cols = ["Ticket ID", "Issue", "Description", "Category"]
required_old_cols = ["Ticket ID", "Issue", "Description", "Category", "Resolution"]
if not all(col in df_new_tickets.columns for col in required_new_cols):
    missing_cols = [col for col in required_new_cols if col not in df_new_tickets.columns]
    raise RuntimeError(f"Missing columns in new_tickets.csv: {missing_cols}")
if not all(col in df_old_tickets.columns for col in required_old_cols):
    missing_cols = [col for col in required_old_cols if col not in df_old_tickets.columns]
    raise RuntimeError(f"Missing columns in old_tickets data: {missing_cols}")
def merge_columns(row):
    """Create a new 'Merged' value for the given row by combining its Issue, Description and Category columns."""
    return f"Issue: {row['Issue']}. Description: {row['Description']}. Category: {row['Category']}."
# Create the new columns to be used during the similarity search
df_new_tickets['Merged'] = df_new_tickets.apply(merge_columns, axis=1)
df_old_tickets['Merged'] = df_old_tickets.apply(merge_columns, axis=1)
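# For a row like (Issue="VPN drops", Description="Disconnects hourly",
# Category="Network") — hypothetical values — the merged text reads:
#   "Issue: VPN drops. Description: Disconnects hourly. Category: Network."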
def get_similarity_scores(query, docs):
    """Calculate the similarity scores of the given query against the existing documents.
    all-MiniLM-L6-v2 is accessed via the Hugging Face Inference API for this computation.
    """
    api_url = "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    payload = {
        "inputs": {
            "source_sentence": query,
            "sentences": docs
        },
    }
    response = requests.post(api_url, headers=headers, json=payload)
    # Surface HTTP errors (e.g. model still loading, bad token) instead of
    # silently passing an error payload along as if it were a score list
    response.raise_for_status()
    return response.json()
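# On success the sentence-similarity endpoint returns a JSON list of floats,
# one score per entry in `docs`, e.g. [0.71, 0.12, 0.65, ...].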
def find_relevant_docs(scores, threshold=0.65):
    """Return the indices of up to three documents whose similarity scores are above the threshold."""
    scores = np.array(scores)
    # Filter indices of values greater than the threshold
    valid_indices = np.where(scores > threshold)[0]
    # Sort these indices by their corresponding scores in descending order
    sorted_indices = valid_indices[np.argsort(scores[valid_indices])[::-1]]
    # Keep at most three indices
    top_indices = sorted_indices[:3]
    return top_indices
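# A quick worked example with made-up scores:
#   find_relevant_docs([0.9, 0.2, 0.7, 0.66]) -> array([0, 2, 3])
# 0.2 is filtered out by the 0.65 threshold; the rest come back
# highest-score-first, capped at three.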
@app.post("/retrieve_docs")
def retrieve_docs(query_request: QueryRequest):
""" Given a query and a document collection, find and return the most relevant ones."""
# Prepare a list of old tickets to compare similarities
docs = df_old_tickets['Merged'].to_list()
# Find the requested ticket
ticket_id = query_request.ticket_id
# Check if ticket ID exists
if ticket_id not in df_new_tickets['Ticket ID'].values:
raise HTTPException(status_code=404, detail=f"Ticket ID {ticket_id} not found.")
# Get the query document
query_doc = df_new_tickets.loc[df_new_tickets['Ticket ID'] == ticket_id]["Merged"].values[0]
# Get the indices of the most relevant documents
scores = get_similarity_scores(query_doc, docs)
relevant_indices = find_relevant_docs(scores)
# Return the relevant documents as an array of JSON objects, excluding the last column
return df_old_tickets.iloc[relevant_indices, :-1].to_json(orient="records")
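# Example request (hypothetical ticket ID, server assumed on localhost:8000):
#   curl -X POST http://localhost:8000/retrieve_docs \
#        -H "Content-Type: application/json" \
#        -d '{"ticket_id": "TKT-1001"}'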
@app.post("/get_help")
def generate_response(query_request: QueryRequest):
""" Based on the given ticket and the retrieved information, generate a response using an LLM."""
# Find the requested ticket
ticket_id = query_request.ticket_id
# Check if ticket ID exists BEFORE trying to get query_doc
if ticket_id not in df_new_tickets['Ticket ID'].values:
raise HTTPException(status_code=404, detail=f"Ticket ID {ticket_id} not found.")
# Get the query document
query_doc = df_new_tickets.loc[df_new_tickets['Ticket ID'] == ticket_id]["Merged"].values[0]
# Retrieval
try:
retrieved_docs = retrieve_docs(query_request)
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error retrieving documents: {e}")
    # LLM request
    system_prompt = """You are a helpful AI assistant guiding IT helpdesk workers in resolving open tickets.
Based on the issue description given by the user, a document retriever will hand you relevant past tickets that can help. You generate your answer based on these old tickets.
Use only the information and metadata contained in the retrieved tickets.
DO NOT use your own knowledge, unless it is specifically asked for.
Your goal is to give suggestions for resolving the ticket at hand based on the information provided to you.
The user does not know about the old tickets, so always start by referencing the Ticket IDs that you looked at when generating your response.
If there are tickets provided, format your answer in this way:
#### Ticket at Hand:
'Reprint the ticket given by the user'
#### Tracked Tickets:
'Briefly summarize the information provided to you based on the older tickets' descriptions and resolutions. You can list them by their Ticket IDs'
#### Suggestions:
'Suggest a solution to the open ticket, or a person of contact, based on what you were given.'
"""
    user_prompt = f"""I need help with the following ticket:
### Ticket at Hand:
{query_doc}
### Retrieved Tickets:
{retrieved_docs}
"""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    try:
        completion = client.chat.completions.create(
            # model="meta-llama/Llama-3.1-8B-Instruct",  # Requires PRO subscription
            model="meta-llama/Llama-3.2-3B-Instruct",
            messages=messages,
            max_tokens=256,
            temperature=0,
            stream=False
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during LLM completion: {e}")
    # Return the generated response
    return completion.choices[0].message.content
if __name__ == "__main__":
    # FastAPI apps are served by an ASGI server such as uvicorn; app.run()
    # is Flask's API and does not exist on a FastAPI instance.
    # Set reload=True for hot reload during development.
    uvicorn.run("app:app", host="127.0.0.1", port=8000, reload=False)
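# Once the server is up (assuming the defaults above), the /get_help endpoint
# is exercised the same way as /retrieve_docs, e.g.:
#   curl -X POST http://127.0.0.1:8000/get_help \
#        -H "Content-Type: application/json" \
#        -d '{"ticket_id": "TKT-1001"}'
# The ticket ID here is hypothetical; use one present in new_tickets.csv.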