app.py
import torch
from transformers import pipeline
from typing import Union

# FastAPI imports
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

app = FastAPI()
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
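# Note: allow_origins=["*"] combined with wildcard methods/headers makes this a
# fully permissive CORS setup; restrict the origins for production deployments.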
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"##### Device is {torch_device} #####")

# Create model pipeline
pipe = pipeline(
    "text-generation",
    model="./data/models/Meta-Llama-3-8B-Instruct",
    # model="meta-llama/Meta-Llama-3-8B-Instruct",  # to download the model from HF instead (login required)
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=torch_device,
)
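# The local path above is assumed to already contain the downloaded
# Meta-Llama-3-8B-Instruct weights; otherwise switch to the commented-out
# Hugging Face model id (login required).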
# Defaults used when the request does not provide generation config values.
default_generation_config = {
    "temperature": 0.2,
    "top_p": 0.9,
    "max_new_tokens": 256,
}

# Default system prompt, used when the request does not include one.
default_system_prompt = "You are a helpful assistant called Llama-3. Keep your answers short and succinct!"
# model.eval()
# if torch.__version__ >= "2":
# model = torch.compile(model)
print("##### Model is loaded #####")
# Data model for making POST requests to /chat
class ChatRequest(BaseModel):
    messages: list
    temperature: Union[float, None] = None
    top_p: Union[float, None] = None
    max_new_tokens: Union[int, None] = None
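# Illustrative request body for POST /chat (only "messages" is required; the
# other fields fall back to default_generation_config when omitted):
# {
#     "messages": [{"role": "user", "content": "Hello, who are you?"}],
#     "temperature": 0.7,
#     "top_p": 0.9,
#     "max_new_tokens": 128
# }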
def generate(messages: list, temperature: float = None, top_p: float = None, max_new_tokens: int = None) -> str:
    """Generates a response given a list of messages (conversation history) and the generation configuration."""
    # Fall back to the defaults when a value is not provided.
    temperature = temperature if temperature is not None else default_generation_config["temperature"]
    top_p = top_p if top_p is not None else default_generation_config["top_p"]
    max_new_tokens = max_new_tokens if max_new_tokens is not None else default_generation_config["max_new_tokens"]

    prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    terminators = [
        pipe.tokenizer.eos_token_id,
        pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,  # sampling must be enabled, otherwise temperature/top_p are ignored
        temperature=temperature,
        top_p=top_p,
    )

    generated_outputs = outputs[0]["generated_text"]  # full prompt + response
    text = generated_outputs[len(prompt):]  # keep just the response
    return text
def isSystemPrompt(msg):
    """Returns True if the given message is a system prompt."""
    return msg["role"] == "system"
@app.get("/home", response_class=HTMLResponse)
def home():
    """Mainly for quick testing. If the service is running, you should see a full welcome message generated by the model."""
    input_text = "Write a welcome message for the home page of a large language model chatbot"
    messages = [
        {"role": "system", "content": default_system_prompt},
        {"role": "user", "content": input_text},
    ]
    print("##### Generating welcome response #####")
    response = generate(messages)
    # Returned as HTML (response_class above) so the heading renders in the browser.
    welcome_text = "<h2>Welcome to LLM service!</h2>"
    welcome_text += response
    return welcome_text
@app.post("/chat")
def chat(chat_request: ChatRequest):
    """The main endpoint for interacting with the model.
    A list of messages is required, but the other config parameters can be left empty.
    Providing an initial system prompt in the messages is also optional."""
    messages = chat_request.messages
    temperature = chat_request.temperature
    top_p = chat_request.top_p
    max_new_tokens = chat_request.max_new_tokens

    # Check for a system prompt; prepend the default one if necessary.
    if not messages or not isSystemPrompt(messages[0]):
        msg = {"role": "system", "content": default_system_prompt}
        messages.insert(0, msg)

    print("##### Generating response... #####")
    response = generate(messages, temperature, top_p, max_new_tokens)
    return response
if __name__ == "__main__":
    # FastAPI apps are served by an ASGI server such as uvicorn (there is no
    # app.run() as in Flask). For hot reload during development, start the
    # server with: uvicorn app:app --reload
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)  # adjust host/port as needed
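# Example client calls (assuming the service is reachable on localhost:8000):
#   curl http://localhost:8000/home
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello!"}]}'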