Non-streaming OpenAI Chat Completions
alpayariyak committed Jan 26, 2024
1 parent 4cebe66 commit 9fc8e1e
Showing 3 changed files with 47 additions and 40 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -151,7 +151,7 @@ You may either use a `prompt` or a list of `messages` as input. If you use `mess
|-----------------------|----------------------|--------------------|--------------------------------------------------------------------------------------------------------|
| `prompt` | str | | Prompt string to generate text based on. |
| `messages` | list[dict[str, str]] | | List of messages, which will automatically have the model's chat template applied. Overrides `prompt`. |
| `use_openai_format` | bool | False | Whether to return output in OpenAI format. `ALLOW_OPENAI_FORMAT` environment variable must be `1`, the input must be a `messages` list, and `stream` enabled. |
| `use_openai_format` | bool | False | Whether to return output in OpenAI format. `ALLOW_OPENAI_FORMAT` environment variable must be `1`, the input should preferably be a `messages` list, but `prompt` is accepted. |
| `apply_chat_template` | bool | False | Whether to apply the model's chat template to the `prompt`. |
| `sampling_params` | dict | {} | Sampling parameters to control the generation, like temperature, top_p, etc. |
| `stream` | bool | False | Whether to enable streaming of output. If True, responses are streamed as they are generated. |
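
For illustration, a hedged sketch of a job input that exercises the relaxed `use_openai_format` rule above — a bare `prompt` instead of `messages`. The key names follow the table; the concrete values and the omitted request envelope are assumptions, not part of this diff:

```python
# Illustrative only: a worker job input using the relaxed `use_openai_format` rule.
job_input = {
    "prompt": "Write a haiku about GPUs.",   # plain string; the worker wraps it as a user message
    "use_openai_format": True,               # requires ALLOW_OPENAI_FORMAT=1 in the environment
    "stream": False,                         # non-streaming is what this commit enables
    "sampling_params": {"temperature": 0.7, "max_tokens": 64},
}
```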
src/engine.py (80 changes: 42 additions & 38 deletions)
@@ -7,7 +7,7 @@
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from transformers import AutoTokenizer
from utils import count_physical_cores
from utils import count_physical_cores, DummyRequest
from constants import DEFAULT_MAX_CONCURRENCY
from dotenv import load_dotenv

@@ -106,55 +106,59 @@ async def generate_vllm(self, llm_input, validated_sampling_params, batch_size,

async def generate_openai_chat(self, llm_input, validated_sampling_params, batch_size, stream, apply_chat_template, request_id: str) -> AsyncGenerator[dict, None]:

if not isinstance(llm_input, list):
raise ValueError("Input must be a list of messages")

if not stream:
raise ValueError("OpenAI Chat Completion Format only supports streaming")
if isinstance(llm_input, str):
llm_input = [{"role": "user", "content": llm_input}]
logging.warning("OpenAI Chat Completion format requires list input, converting to list and assigning 'user' role")

if not self.openai_engine:
raise ValueError("OpenAI Chat Completion format is disabled")

chat_completion_request = ChatCompletionRequest(
model=self.config["model"],
messages=llm_input,
stream=True,
stream=stream,
**validated_sampling_params,
)

response_generator = await self.openai_engine.create_chat_completion(chat_completion_request, None) # None for raw_request
batch_contents = {}
batch_latest_choices = {}
batch_token_counter = 0
last_chunk = {}

async for chunk_str in response_generator:
try:
chunk = json.loads(chunk_str.removeprefix("data: ").rstrip("\n\n"))
except:
continue

if "choices" in chunk:
for choice in chunk["choices"]:
choice_index = choice["index"]
if "delta" in choice and "content" in choice["delta"]:
batch_contents[choice_index] = batch_contents.get(choice_index, []) + [choice["delta"]["content"]]
batch_latest_choices[choice_index] = choice
batch_token_counter += 1
last_chunk = chunk
response_generator = await self.openai_engine.create_chat_completion(chat_completion_request, DummyRequest())
if not stream:
yield json.loads(response_generator.model_dump_json())
else:
batch_contents = {}
batch_latest_choices = {}
batch_token_counter = 0
last_chunk = {}

if batch_token_counter >= batch_size:
async for chunk_str in response_generator:
try:
chunk = json.loads(chunk_str.removeprefix("data: ").rstrip("\n\n"))
except:
continue

if "choices" in chunk:
for choice in chunk["choices"]:
choice_index = choice["index"]
if "delta" in choice and "content" in choice["delta"]:
batch_contents[choice_index] = batch_contents.get(choice_index, []) + [choice["delta"]["content"]]
batch_latest_choices[choice_index] = choice
batch_token_counter += 1
last_chunk = chunk

if batch_token_counter >= batch_size:
for choice_index in batch_latest_choices:
batch_latest_choices[choice_index]["delta"]["content"] = batch_contents[choice_index]
last_chunk["choices"] = list(batch_latest_choices.values())
yield last_chunk

batch_contents = {}
batch_latest_choices = {}
batch_token_counter = 0

if batch_contents:
for choice_index in batch_latest_choices:
batch_latest_choices[choice_index]["delta"]["content"] = batch_contents[choice_index]
last_chunk["choices"] = list(batch_latest_choices.values())
yield last_chunk

batch_contents = {}
batch_latest_choices = {}
batch_token_counter = 0

if batch_contents:
for choice_index in batch_latest_choices:
batch_latest_choices[choice_index]["delta"]["content"] = batch_contents[choice_index]
last_chunk["choices"] = list(batch_latest_choices.values())
yield last_chunk

def _initialize_config(self):
quantization = self._get_quantization()
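The heart of the change is in `generate_openai_chat`: `create_chat_completion` is now called with a `DummyRequest()` placeholder instead of `None`, and its result is handled in two branches. Below is a simplified, hedged sketch of that branching — not the engine's exact code. It assumes, as the diff itself does, that vLLM returns a pydantic response object when `stream=False` and an async generator of `data: ...` SSE strings when `stream=True`, and it omits the worker's token-batching of streamed chunks:

```python
import json

# Hypothetical, simplified relay mirroring the branching added in this commit.
# `openai_engine` stands for vLLM's OpenAIServingChat instance, `request` for a
# ChatCompletionRequest, and `dummy_request` for the DummyRequest helper added
# in src/utils.py below.
async def relay_chat_completion(openai_engine, request, dummy_request, stream: bool):
    result = await openai_engine.create_chat_completion(request, dummy_request)

    if not stream:
        # Non-streaming: `result` is a pydantic model; round-trip it through
        # JSON so the handler receives a plain dict.
        yield json.loads(result.model_dump_json())
    else:
        # Streaming: `result` is an async generator of SSE lines.
        async for chunk_str in result:
            try:
                # Strip the "data: " prefix; the final "data: [DONE]" sentinel
                # is not JSON and is simply skipped.
                yield json.loads(chunk_str.removeprefix("data: ").rstrip("\n"))
            except json.JSONDecodeError:
                continue
```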
src/utils.py (5 changes: 4 additions & 1 deletion)
@@ -46,4 +46,7 @@ def __init__(self, job):
self.use_openai_format = job.get("use_openai_format", False)
self.validated_sampling_params = validate_sampling_params(job.get("sampling_params", {}))
self.request_id = random_uuid()


class DummyRequest:
async def is_disconnected(self):
return False
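
The new `DummyRequest` stands in for the raw HTTP request that vLLM's OpenAI serving layer normally receives; the serverless worker has no such request object, so the stand-in always reports that the client is still connected (the assumption being that the serving layer only awaits `is_disconnected()` on it). A tiny usage sketch:

```python
import asyncio
from utils import DummyRequest  # the helper added above in src/utils.py

async def main():
    req = DummyRequest()
    # Always False: the serverless worker never has a client-side disconnect to report.
    assert await req.is_disconnected() is False

asyncio.run(main())
```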
