Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
nlueem committed Aug 3, 2024
1 parent 4a2ea81 commit 4c0441c
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
This module sets up a FastAPI server to interact with a text-generation model.
It uses Hugging Face transformers, Pydantic for request validation, and Torch for device management.
"""

import os
from typing import Union

Expand All @@ -12,11 +17,6 @@

import torch

"""
This module sets up a FastAPI server to interact with a text-generation model.
It uses Hugging Face transformers, Pydantic for request validation, and Torch for device management.
"""

# Set up logging
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
Expand Down Expand Up @@ -73,6 +73,7 @@

# Data model for making POST requests to /chat
class ChatRequest(BaseModel):
    """Request body for the /chat endpoint.

    Carries the conversation history plus optional per-request sampling
    overrides; any override left as None falls back to the server's
    default generation configuration.
    """
    # Conversation history; presumably a list of {"role": ..., "content": ...}
    # dicts in the Hugging Face chat-template format — TODO confirm against
    # generate()'s apply_chat_template call.
    messages: list
    # Sampling temperature override; None -> use the configured default.
    temperature: Union[float, None] = None
    # Nucleus-sampling override; None -> use the configured default.
    top_p: Union[float, None] = None
Expand All @@ -86,10 +87,18 @@ def generate(messages: list,
"""Generates a response given a list of messages (conversation history)
and the generation configuration."""

temperature = temperature if temperature else default_generation_config["temperature"]
top_p = top_p if top_p else default_generation_config["top_p"]
max_new_tokens = max_new_tokens if max_new_tokens else default_generation_config["max_new_tokens"]

temperature = (
temperature if temperature is not None
else default_generation_config["temperature"]
)
top_p = (
top_p if top_p is not None
else default_generation_config["top_p"]
)
max_new_tokens = (
max_new_tokens if max_new_tokens is not None
else default_generation_config["max_new_tokens"]
)
prompt = pipe.tokenizer.apply_chat_template(
messages,
tokenize=False,
Expand Down Expand Up @@ -137,5 +146,6 @@ def chat(chat_request: ChatRequest):
return response

if __name__ == "__main__":
    # Setting debug to True enables hot reload and provides a debugger shell
    # if you hit an error while running the server.
    # NOTE(review): `app.run(debug=...)` is the Flask serving API; the module
    # docstring says this is a FastAPI app, which is normally served with
    # `uvicorn.run(app, ...)` — confirm `app` really exposes `.run` here.
    app.run(debug=False)

0 comments on commit 4c0441c

Please sign in to comment.