Commit

lint
nlueem committed Aug 3, 2024
1 parent 2e8c19c commit 9110db1
Showing 3 changed files with 28 additions and 21 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1 +1,2 @@
-.DS_Store
+.DS_Store
+.env
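The new `.env` entry implies that secrets such as `HUGGING_FACE_HUB_TOKEN` now live in a local `.env` file that must not be committed. As a minimal sketch of how such a file is typically loaded, assuming the `python-dotenv` package (the diff itself only reads the variable via `os.getenv`):

```python
# Hypothetical sketch, not part of this commit: load key=value pairs from ./.env
# into the process environment before app.py reads them with os.getenv().
import os

from dotenv import load_dotenv  # assumes python-dotenv is installed

load_dotenv()  # silently does nothing if no .env file exists
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
```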
8 changes: 4 additions & 4 deletions README.md
@@ -1,8 +1,8 @@
 ---
-title: restful-Llama-3.1
-emoji: 📈
-colorFrom: purple
-colorTo: yellow
+title: Restful Llama3.1
+emoji: 🔥
+colorFrom: green
+colorTo: indigo
 sdk: docker
 pinned: false
 license: mit
38 changes: 22 additions & 16 deletions app.py
@@ -1,18 +1,19 @@
 import os
+import torch
+from typing import Union
 
 from transformers import pipeline
 from transformers.utils import logging
 
-from typing import Union
 from pydantic import BaseModel
 
 # FastAPI imports
 from fastapi import Request,FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
-import torch
-
 # Set up logging
 logging.set_verbosity_info()
+logging.enable_progress_bar()
 logger = logging.get_logger("transformers")
 
 app = FastAPI()
@@ -28,28 +29,28 @@
 )
 
 # detect host device for torch
-torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-logger.info(f"Device is {torch_device}")
+TORCH_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+logger.info(f"Device is {TORCH_DEVICE}")
 
 # set cache for Hugging Face
-cache_dir = "./cache/"
-os.environ["HF_HOME"] = cache_dir
+CACHE_DIR = "./cache/"
+os.environ["HF_HOME"] = CACHE_DIR
 
 # Load token from env
 HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
 
-# ToDo: add token
 if not HUGGING_FACE_HUB_TOKEN:
     raise ValueError("HUGGING_FACE_HUB_TOKEN secret not set!")
 
 # Create model pipeline
-model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 pipe = pipeline(
     "text-generation",
-    model=model_id,
+    model=MODEL_ID,
     model_kwargs={"torch_dtype": torch.bfloat16},
-    device=torch_device,
-    batch_size=2
+    device=TORCH_DEVICE,
+    token=HUGGING_FACE_HUB_TOKEN,
+    batch_size=4
 )
 
 # Default when config values are not provided by the user.
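For orientation, a `text-generation` pipeline like the one constructed above is usually invoked with a list of chat messages plus sampling parameters. A minimal usage sketch under that assumption; the repository's actual `generate()` body is collapsed in this view and may build the prompt differently:

```python
# Hypothetical usage of the pipe object created above; not code from this commit.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
outputs = pipe(messages, do_sample=True, temperature=0.6, top_p=0.9, max_new_tokens=128)
# For chat-style input, generated_text holds the whole conversation, reply included.
print(outputs[0]["generated_text"][-1]["content"])
```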
@@ -60,7 +61,8 @@
 }
 
 # Default when no system prompt is provided by the user.
-default_system_prompt = "You are a helpful assistant called Llama-3.1. Write out your answer short and succinct!"
+default_system_prompt = """You are a helpful assistant called Llama-3.1.
+Write out your funny and wrong answer in german!"""
 
 logger.info("Model is loaded")

@@ -72,7 +74,10 @@ class ChatRequest(BaseModel):
     max_new_tokens: Union[int, None] = None
 
 
-def generate(messages: list, temperature: float = None, top_p: float = None, max_new_tokens: int = None) -> str:
+def generate(messages: list,
+             temperature: float = None,
+             top_p: float = None,
+             max_new_tokens: int = None) -> str:
     """Generates a response given a list of messages (conversation history)
     and the generation configuration."""

@@ -103,7 +108,8 @@ def generate(messages: list, temperature: float = None, top_p: float = None, max
     text = generated_outputs[len(prompt):]
     return text
 
-def isSystemPrompt(msg):
+def is_system_prompt(msg):
+    """Check if a message is a system prompt."""
     return msg["role"] == "system"
 
 @app.post("/chat")
@@ -117,7 +123,7 @@ def chat(chat_request: ChatRequest):
     top_p=chat_request.top_p
     max_new_tokens=chat_request.max_new_tokens
 
-    if not isSystemPrompt(messages[0]):
+    if not is_system_prompt(messages[0]):
         messages.insert(0, {"role": "system", "content": default_system_prompt})
 
     logger.info("Generating response...")
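Taken together, app.py exposes a single POST `/chat` route. A minimal client sketch; the field names are inferred from `ChatRequest` and the `chat()` body visible in this diff, while the host and port (7860, the Hugging Face Spaces default) are assumptions:

```python
# Hypothetical client for the /chat endpoint; not part of this commit.
import requests

resp = requests.post(
    "http://localhost:7860/chat",  # assumed address; not shown in the diff
    json={
        "messages": [{"role": "user", "content": "Hi there!"}],
        "temperature": 0.7,        # optional; server-side defaults apply if omitted
        "top_p": 0.9,
        "max_new_tokens": 128,
    },
    timeout=120,
)
print(resp.json())
```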
