Commit

lint
nlueem committed Aug 3, 2024
1 parent 9110db1 commit 4a2ea81
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions app.py
@@ -7,11 +7,16 @@
 from pydantic import BaseModel
 
 # FastAPI imports
-from fastapi import Request,FastAPI
+from fastapi import Request, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
 import torch
 
+"""
+This module sets up a FastAPI server to interact with a text-generation model.
+It uses Hugging Face transformers, Pydantic for request validation, and Torch for device management.
+"""
+
 # Set up logging
 logging.set_verbosity_info()
 logger = logging.get_logger("transformers")
@@ -28,11 +33,11 @@
     allow_headers=["*"],
 )
 
-# detect host device for torch
+# Detect host device for torch
 TORCH_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 logger.info(f"Device is {TORCH_DEVICE}")
 
-# set cache for Hugging Face
+# Set cache for Hugging Face
 CACHE_DIR = "./cache/"
 os.environ["HF_HOME"] = CACHE_DIR
 
@@ -55,18 +60,18 @@
 
 # Default when config values are not provided by the user.
 default_generation_config = {
-    "temperature": 0.2, #0.2
+    "temperature": 0.2,
     "top_p": 0.9,
-    "max_new_tokens": 256, #128
+    "max_new_tokens": 256,
 }
 
 # Default when no system prompt is provided by the user.
-default_system_prompt = """You are a helpful assistant called Llama-3.1.
-Write out your funny and wrong answer in german!"""
+DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant called Llama-3.1.
+Write out your funny and wrong answer in German!"""
 
 logger.info("Model is loaded")
 
-# Data model for making POST requests to /chat
+# Data model for making POST requests to /chat
 class ChatRequest(BaseModel):
     messages: list
     temperature: Union[float, None] = None
@@ -119,12 +124,12 @@ def chat(chat_request: ChatRequest):
     Providing an initial system prompt in the messages is also optional."""
 
     messages = chat_request.messages
-    temperature=chat_request.temperature
-    top_p=chat_request.top_p
-    max_new_tokens=chat_request.max_new_tokens
+    temperature = chat_request.temperature
+    top_p = chat_request.top_p
+    max_new_tokens = chat_request.max_new_tokens
 
     if not is_system_prompt(messages[0]):
-        messages.insert(0, {"role": "system", "content": default_system_prompt})
+        messages.insert(0, {"role": "system", "content": DEFAULT_SYSTEM_PROMPT})
 
     logger.info("Generating response...")
     response = generate(messages, temperature, top_p, max_new_tokens)
@@ -133,4 +138,4 @@ def chat(chat_request: ChatRequest):
 
 if __name__ == "__main__":
     # Setting debug to True enables hot reload and provides a debugger shell if you hit an error while running the server
-    app.run(debug=False)
+    app.run(debug=False)
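
Note: app.run(debug=False) is Flask-style; a FastAPI application object has no run method, so this line would fail at runtime with an AttributeError. A minimal sketch of serving the app with uvicorn instead (host and port here are illustrative assumptions, not part of this commit):

import uvicorn

if __name__ == "__main__":
    # Serve the FastAPI app defined in app.py; setting reload=True would enable hot reload during development
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)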

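For reference, a minimal client-side sketch for the /chat endpoint defined by the ChatRequest model above. The host, port, and printed response shape are assumptions for illustration, not part of this commit:

import requests

# All generation parameters are optional; the server falls back to its defaults
# (temperature 0.2, top_p 0.9, max_new_tokens 256) when they are omitted.
payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 128,
}

# Assumes the server is reachable locally on port 8000 and returns JSON.
resp = requests.post("http://localhost:8000/chat", json=payload)
print(resp.json())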