@@ -1,3 +1,4 @@
+import asyncio
 import re
 from functools import lru_cache
 from io import StringIO
@@ -6,6 +7,7 @@
 from typing import Any, Callable, Container, Dict, Generator, List, Optional, Type, Union

 import torch
+from fastapi import Request
 from huggingface_hub.utils import EntryNotFoundError
 from optimum.intel import OVModelForCausalLM, OVModelForSeq2SeqLM
 from transformers import (
@@ -61,11 +63,15 @@ def get_model(checkpoint: str, device: str = "CPU") -> OVModel:
     return model


+# TODO: the generator needs a running flag, or cancellation of the active generation,
+# when a new request arrives; it cannot handle concurrent requests - a second request fails and stalls the process:
+# RuntimeError: Exception from src/inference/src/infer_request.cpp:189:
+# [ REQUEST_BUSY ]
 class GeneratorFunctor:
     def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
         raise NotImplementedError

-    async def generate_stream(self, input_text: str, parameters: Dict[str, Any]):
+    async def generate_stream(self, input_text: str, parameters: Dict[str, Any], request: Request):
         raise NotImplementedError

     def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
@@ -122,24 +128,45 @@ def __call__(
         logger.info(f"Number of input tokens: {prompt_len}; generated {len(output_ids)} tokens")
         return self.tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

-    async def generate_stream(
-        self, input_text: str, parameters: Dict[str, Any], stopping_criteria: Optional[StoppingCriteriaList] = None
-    ):
+    async def generate_stream(self, input_text: str, parameters: Dict[str, Any], request: Request = None):
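+        # the incoming FastAPI Request is watched below so that a client disconnect can
+        # cancel generation; it defaults to None, so existing callers may omit it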
         input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
         streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
         parameters["streamer"] = streamer
         config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
+
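+        # no stop-token ids are registered here; the criteria object is used purely as a
+        # cooperative cancellation hook (see StopOnTokens.cancelled below)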
+        stop_on_tokens = StopOnTokens([])
+
         generation_kwargs = dict(
             input_ids=input_ids,
             streamer=streamer,
-            stopping_criteria=stopping_criteria,
+            stopping_criteria=StoppingCriteriaList([stop_on_tokens]),
             **config.to_dict(),
         )
+
+        # listen for the client disconnect event so generation can be stopped
+        def listen_for_disconnect():
+            async def listen():
+                message = await request.receive()
+                if message.get("type") == "http.disconnect":
+                    stop_on_tokens.cancelled = True
+            asyncio.create_task(listen())
+
+        listen_thread = Thread(target=listen_for_disconnect)
+        # Thread.run() does not start a new thread: it runs the target in the current thread,
+        # which is required because asyncio.create_task() must be called from the thread that
+        # owns the running event loop, so Thread.start() would not work here
+        listen_thread.run()
+
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
+
         for token in streamer:
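+            # yield to the event loop between tokens so the disconnect-listener task gets a
+            # chance to run while generation is streaming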
+            await asyncio.sleep(0.01)
             yield token

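+        # wait for the generation thread; the stopping criteria makes it finish early
+        # once a disconnect has been detected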
+        thread.join()
+
     def generate_between(
         self,
         input_parts: List[str],
@@ -243,7 +270,10 @@ def inner() -> GeneratorFunctor:

 class StopOnTokens(StoppingCriteria):
     def __init__(self, token_ids: List[int]) -> None:
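+        # flipped to True by the request-disconnect listener; checked on every generation
+        # step so a dropped client cancels generation cooperatively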
+        self.cancelled = False
         self.token_ids = torch.tensor(token_ids, requires_grad=False)

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        if self.cancelled:
+            return True
         return torch.any(torch.eq(input_ids[0, -1], self.token_ids)).item()
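
For context, a minimal endpoint sketch (not part of this commit; the route, payload shape, and media type are assumptions) showing how a FastAPI handler could pass its Request into generate_stream so that a client disconnect cancels generation:

# hypothetical usage sketch - names and wiring are assumptions, not part of this commit
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()
generator = ...  # a GeneratorFunctor implementation from this module

@app.post("/api/generate")
async def generate(request: Request) -> StreamingResponse:
    payload = await request.json()
    return StreamingResponse(
        # passing the Request through lets the disconnect listener stop generation early
        generator.generate_stream(payload["inputs"], payload.get("parameters", {}), request),
        media_type="text/event-stream",
    )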