diff --git a/README.md b/README.md
index 84a3be9..685e27a 100644
--- a/README.md
+++ b/README.md
@@ -151,7 +151,7 @@ You may either use a `prompt` or a list of `messages` as input. If you use `mess
 |-----------------------|----------------------|--------------------|--------------------------------------------------------------------------------------------------------|
 | `prompt`              | str                  |                    | Prompt string to generate text based on.                                                                |
 | `messages`            | list[dict[str, str]] |                    | List of messages, which will automatically have the model's chat template applied. Overrides `prompt`.  |
-| `use_openai_format`   | bool                 | False              | Whether to return output in OpenAI format. `ALLOW_OPENAI_FORMAT` environment variable must be `1`, the input must be a `messages` list, and `stream` enabled. |
+| `use_openai_format`   | bool                 | False              | Whether to return output in OpenAI format. The `ALLOW_OPENAI_FORMAT` environment variable must be `1`; the input should preferably be a `messages` list, but `prompt` is also accepted. |
 | `apply_chat_template` | bool                 | False              | Whether to apply the model's chat template to the `prompt`.                                             |
 | `sampling_params`     | dict                 | {}                 | Sampling parameters to control the generation, like temperature, top_p, etc.                            |
 | `stream`              | bool                 | False              | Whether to enable streaming of output. If True, responses are streamed as they are generated.           |
diff --git a/src/engine.py b/src/engine.py
index b39c559..3f69139 100644
--- a/src/engine.py
+++ b/src/engine.py
@@ -7,7 +7,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from transformers import AutoTokenizer
 
-from utils import count_physical_cores
+from utils import count_physical_cores, DummyRequest
 from constants import DEFAULT_MAX_CONCURRENCY
 from dotenv import load_dotenv
 
@@ -106,55 +106,59 @@ async def generate_vllm(self, llm_input, validated_sampling_params, batch_size,
 
 
     async def generate_openai_chat(self, llm_input, validated_sampling_params, batch_size, stream, apply_chat_template, request_id: str) -> AsyncGenerator[dict, None]:
-        if not isinstance(llm_input, list):
-            raise ValueError("Input must be a list of messages")
-
-        if not stream:
-            raise ValueError("OpenAI Chat Completion Format only supports streaming")
+        if isinstance(llm_input, str):
+            llm_input = [{"role": "user", "content": llm_input}]
+            logging.warning("OpenAI Chat Completion format requires list input, converting to list and assigning 'user' role")
+
+        if not self.openai_engine:
+            raise ValueError("OpenAI Chat Completion format is disabled")
 
         chat_completion_request = ChatCompletionRequest(
             model=self.config["model"],
            messages=llm_input,
-            stream=True,
+            stream=stream,
             **validated_sampling_params,
         )
 
-        response_generator = await self.openai_engine.create_chat_completion(chat_completion_request, None) # None for raw_request
-        batch_contents = {}
-        batch_latest_choices = {}
-        batch_token_counter = 0
-        last_chunk = {}
-
-        async for chunk_str in response_generator:
-            try:
-                chunk = json.loads(chunk_str.removeprefix("data: ").rstrip("\n\n"))
-            except:
-                continue
-
-            if "choices" in chunk:
-                for choice in chunk["choices"]:
-                    choice_index = choice["index"]
-                    if "delta" in choice and "content" in choice["delta"]:
-                        batch_contents[choice_index] = batch_contents.get(choice_index, []) + [choice["delta"]["content"]]
-                    batch_latest_choices[choice_index] = choice
-                    batch_token_counter += 1
-                last_chunk = chunk
+        response_generator = await self.openai_engine.create_chat_completion(chat_completion_request, DummyRequest())
+        if not stream:
+            yield json.loads(response_generator.model_dump_json())
+        else:
+            batch_contents = {}
+            batch_latest_choices = {}
+            batch_token_counter = 0
+            last_chunk = {}
 
-            if batch_token_counter >= batch_size:
+            async for chunk_str in response_generator:
+                try:
+                    chunk = json.loads(chunk_str.removeprefix("data: ").rstrip("\n\n"))
+                except:
+                    continue
+
+                if "choices" in chunk:
+                    for choice in chunk["choices"]:
+                        choice_index = choice["index"]
+                        if "delta" in choice and "content" in choice["delta"]:
+                            batch_contents[choice_index] = batch_contents.get(choice_index, []) + [choice["delta"]["content"]]
+                        batch_latest_choices[choice_index] = choice
+                        batch_token_counter += 1
+                    last_chunk = chunk
+
+                if batch_token_counter >= batch_size:
+                    for choice_index in batch_latest_choices:
+                        batch_latest_choices[choice_index]["delta"]["content"] = batch_contents[choice_index]
+                    last_chunk["choices"] = list(batch_latest_choices.values())
+                    yield last_chunk
+
+                    batch_contents = {}
+                    batch_latest_choices = {}
+                    batch_token_counter = 0
+
+            if batch_contents:
                 for choice_index in batch_latest_choices:
                     batch_latest_choices[choice_index]["delta"]["content"] = batch_contents[choice_index]
                 last_chunk["choices"] = list(batch_latest_choices.values())
                 yield last_chunk
-
-                batch_contents = {}
-                batch_latest_choices = {}
-                batch_token_counter = 0
-
-        if batch_contents:
-            for choice_index in batch_latest_choices:
-                batch_latest_choices[choice_index]["delta"]["content"] = batch_contents[choice_index]
-            last_chunk["choices"] = list(batch_latest_choices.values())
-            yield last_chunk
 
     def _initialize_config(self):
         quantization = self._get_quantization()
diff --git a/src/utils.py b/src/utils.py
index 4c81194..ff397c8 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -46,4 +46,7 @@ def __init__(self, job):
         self.use_openai_format = job.get("use_openai_format", False)
         self.validated_sampling_params = validate_sampling_params(job.get("sampling_params", {}))
         self.request_id = random_uuid()
-        
\ No newline at end of file
+
+class DummyRequest:
+    async def is_disconnected(self):
+        return False
\ No newline at end of file
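
A minimal sketch of how the relaxed behavior could be exercised after this change. The `engine` instance, the `demo` helper, and the exact sampling keys are assumptions for illustration; only `generate_openai_chat`, its parameters, and `DummyRequest` come from this diff.

```python
# Minimal usage sketch (not part of the diff). Assumes `engine` is an already
# initialized instance of the engine wrapper in src/engine.py; the helper name
# and sampling parameter values are hypothetical.
import asyncio

async def demo(engine):
    # A plain string prompt is now accepted: it is wrapped as
    # [{"role": "user", "content": ...}] with a warning instead of raising.
    async for chunk in engine.generate_openai_chat(
        llm_input="Why is the sky blue?",
        validated_sampling_params={"temperature": 0.7, "max_tokens": 64},
        batch_size=8,
        stream=False,  # non-streaming now yields one full chat completion dict
        apply_chat_template=False,
        request_id="demo-request",
    ):
        print(chunk)

# asyncio.run(demo(engine))
```

With `stream=False` the generator yields a single dict (the parsed full response); with `stream=True` it keeps the existing batched delta behavior.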