
Commit a0730fa

black reformatted files
1 parent 3f31661 commit a0730fa
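
The changes below are consistent with running the Black code formatter over the server sources: statements that overflow the configured line length are wrapped at the opening bracket, call arguments are exploded one per line, and a trailing comma is added. A minimal before/after sketch of the pattern, mirrored from the get_model hunk below (Black's default 88-character line length is an assumption; the commit does not record the invocation):

# Before: one call longer than the configured line limit
model = model_class.from_pretrained(checkpoint, ov_config=ov_config, compile=False, device=device, trust_remote_code=True)

# After: Black wraps at the bracket, one argument per line, with a trailing
# comma (the trailing comma keeps future argument additions to one-line diffs)
model = model_class.from_pretrained(
    checkpoint,
    ov_config=ov_config,
    compile=False,
    device=device,
    trust_remote_code=True,
)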

File tree

3 files changed: +158 -40 lines


modules/openvino_code/server/src/app.py

+21 -5

@@ -54,7 +54,9 @@ class GenerationDocStringRequest(BaseModel):
         description="Doc string format passed from extension settings [google | numpy | sphinx | dockblockr | ...]",
         example="numpy",
     )
-    definition: str = Field("", description="Function signature", example="def fibonacci(n):")
+    definition: str = Field(
+        "", description="Function signature", example="def fibonacci(n):"
+    )
     parameters: GenerationParameters


@@ -111,10 +113,16 @@ async def generate_stream(
     request: Request,
     generator: GeneratorFunctor = Depends(get_generator_dummy),
 ) -> StreamingResponse:
-    generation_request = TypeAdapter(GenerationRequest).validate_python(await request.json())
+    generation_request = TypeAdapter(GenerationRequest).validate_python(
+        await request.json()
+    )
     logger.info(generation_request)
     return StreamingResponse(
-        generator.generate_stream(generation_request.inputs, generation_request.parameters.model_dump(), request)
+        generator.generate_stream(
+            generation_request.inputs,
+            generation_request.parameters.model_dump(),
+            request,
+        )
     )


@@ -127,7 +135,11 @@ async def summarize(

     start = perf_counter()
     generated_text: str = generator.summarize(
-        request.inputs, request.template, request.definition, request.format, request.parameters.model_dump()
+        request.inputs,
+        request.template,
+        request.definition,
+        request.format,
+        request.parameters.model_dump(),
     )
     stop = perf_counter()

@@ -148,6 +160,10 @@ async def summarize_stream(
     logger.info(request)
     return StreamingResponse(
         generator.summarize_stream(
-            request.inputs, request.template, request.definition, request.format, request.parameters.model_dump()
+            request.inputs,
+            request.template,
+            request.definition,
+            request.format,
+            request.parameters.model_dump(),
         )
     )

modules/openvino_code/server/src/generators.py

+131 -33

@@ -5,7 +5,17 @@
 from pathlib import Path
 from threading import Thread
 from time import time
-from typing import Any, Callable, Container, Dict, Generator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Container,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Type,
+    Union,
+)

 import torch
 from fastapi import Request

@@ -30,14 +40,18 @@
 model_dir = Path("models")
 model_dir.mkdir(exist_ok=True)

-SUMMARIZE_INSTRUCTION = "{function}\n\n# The function with {style} style docstring\n\n{signature}\n"
+SUMMARIZE_INSTRUCTION = (
+    "{function}\n\n# The function with {style} style docstring\n\n{signature}\n"
+)
 SUMMARIZE_STOP_TOKENS = ("\r\n", "\n")


 def get_model_class(checkpoint: Union[str, Path]) -> Type[OVModel]:
     config = AutoConfig.from_pretrained(checkpoint)
     architecture: str = config.architectures[0]
-    if architecture.endswith("ConditionalGeneration") or architecture.endswith("Seq2SeqLM"):
+    if architecture.endswith("ConditionalGeneration") or architecture.endswith(
+        "Seq2SeqLM"
+    ):
         return OVModelForSeq2SeqLM

     return OVModelForCausalLM

@@ -48,16 +62,27 @@ def get_model(checkpoint: str, device: str = "CPU") -> OVModel:
     model_path = model_dir / Path(checkpoint)
     if model_path.exists():
         model_class = get_model_class(model_path)
-        model = model_class.from_pretrained(model_path, ov_config=ov_config, compile=False, device=device)
+        model = model_class.from_pretrained(
+            model_path, ov_config=ov_config, compile=False, device=device
+        )
     else:
         model_class = get_model_class(checkpoint)
         try:
             model = model_class.from_pretrained(
-                checkpoint, ov_config=ov_config, compile=False, device=device, trust_remote_code=True
+                checkpoint,
+                ov_config=ov_config,
+                compile=False,
+                device=device,
+                trust_remote_code=True,
             )
         except EntryNotFoundError:
             model = model_class.from_pretrained(
-                checkpoint, ov_config=ov_config, export=True, compile=False, device=device, trust_remote_code=True
+                checkpoint,
+                ov_config=ov_config,
+                export=True,
+                compile=False,
+                device=device,
+                trust_remote_code=True,
             )
             model.save_pretrained(model_path)
     model.compile()

@@ -72,13 +97,29 @@ class GeneratorFunctor:
     def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
         raise NotImplementedError

-    async def generate_stream(self, input_text: str, parameters: Dict[str, Any], request: Request):
+    async def generate_stream(
+        self, input_text: str, parameters: Dict[str, Any], request: Request
+    ):
         raise NotImplementedError

-    def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
+    def summarize(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ):
         raise NotImplementedError

-    def summarize_stream(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
+    def summarize_stream(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ):
         raise NotImplementedError


@@ -113,9 +154,14 @@ def __init__(
         if summarize_stop_tokens:
             stop_tokens = []
             for token_id in self.tokenizer.vocab.values():
-                if any(stop_word in self.tokenizer.decode(token_id) for stop_word in summarize_stop_tokens):
+                if any(
+                    stop_word in self.tokenizer.decode(token_id)
+                    for stop_word in summarize_stop_tokens
+                ):
                     stop_tokens.append(token_id)
-            self.summarize_stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_tokens)])
+            self.summarize_stopping_criteria = StoppingCriteriaList(
+                [StopOnTokens(stop_tokens)]
+            )

     def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
         input_ids = self.tokenizer.encode(input_text, return_tensors="pt")

@@ -126,20 +172,36 @@ def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
         stopping_criteria = StoppingCriteriaList([stop_on_time])

         prompt_len = input_ids.shape[-1]
-        config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
+        config = GenerationConfig.from_dict(
+            {**self.generation_config.to_dict(), **parameters}
+        )
         output_ids = self.model.generate(
-            input_ids, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+            input_ids,
+            generation_config=config,
+            stopping_criteria=stopping_criteria,
+            **self.assistant_model_config,
         )[0][prompt_len:]
-        logger.info(f"Number of input tokens: {prompt_len}; generated {len(output_ids)} tokens")
-        return self.tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        logger.info(
+            f"Number of input tokens: {prompt_len}; generated {len(output_ids)} tokens"
+        )
+        return self.tokenizer.decode(
+            output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )

     async def generate_stream(
-        self, input_text: str, parameters: Dict[str, Any], request: Optional[Request] = None
+        self,
+        input_text: str,
+        parameters: Dict[str, Any],
+        request: Optional[Request] = None,
     ) -> Generator[str, None, None]:
         input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
-        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
         parameters["streamer"] = streamer
-        config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
+        config = GenerationConfig.from_dict(
+            {**self.generation_config.to_dict(), **parameters}
+        )

         stop_on_tokens = StopOnTokens([])

@@ -180,7 +242,9 @@ def generate_between(
         parameters: Dict[str, Any],
         stopping_criteria: Optional[StoppingCriteriaList] = None,
     ) -> str:
-        config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
+        config = GenerationConfig.from_dict(
+            {**self.generation_config.to_dict(), **parameters}
+        )

         prompt = torch.tensor([[]], dtype=torch.int64)
         buffer = StringIO()

@@ -192,13 +256,20 @@ def generate_between(
             prev_len = prompt.shape[-1]

             prompt = self.model.generate(
-                prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+                prompt,
+                generation_config=config,
+                stopping_criteria=stopping_criteria,
+                **self.assistant_model_config,
             )[
                 :, :-1
             ]  # skip the last token - stop token

-            decoded = self.tokenizer.decode(prompt[0, prev_len:], skip_special_tokens=True)
-            buffer.write(decoded.lstrip(" "))  # hack to delete leadding spaces if there are any
+            decoded = self.tokenizer.decode(
+                prompt[0, prev_len:], skip_special_tokens=True
+            )
+            buffer.write(
+                decoded.lstrip(" ")
+            )  # hack to delete leadding spaces if there are any
         buffer.write(input_parts[-1])
         return buffer.getvalue()

@@ -208,7 +279,9 @@ async def generate_between_stream(
         parameters: Dict[str, Any],
         stopping_criteria: Optional[StoppingCriteriaList] = None,
     ) -> Generator[str, None, None]:
-        config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
+        config = GenerationConfig.from_dict(
+            {**self.generation_config.to_dict(), **parameters}
+        )

         prompt = self.tokenizer.encode(input_parts[0], return_tensors="pt")
         for text_input in input_parts[1:-1]:

@@ -219,12 +292,17 @@
             prev_len = prompt.shape[-1]

             prompt = self.model.generate(
-                prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+                prompt,
+                generation_config=config,
+                stopping_criteria=stopping_criteria,
+                **self.assistant_model_config,
             )[
                 :, :-1
             ]  # skip the last token - stop token

-            decoded = self.tokenizer.decode(prompt[0, prev_len:], skip_special_tokens=True)
+            decoded = self.tokenizer.decode(
+                prompt[0, prev_len:], skip_special_tokens=True
+            )
             yield decoded.lstrip(" ")  # hack to delete leadding spaces if there are any

         yield input_parts[-1]

@@ -237,24 +315,40 @@ def summarization_input(function: str, signature: str, style: str) -> str:
             signature=signature,
         )

-    def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]) -> str:
+    def summarize(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ) -> str:
         prompt = self.summarization_input(input_text, signature, style)
         splited_template = re.split(r"\$\{.*\}", template)
         splited_template[0] = prompt + splited_template[0]

-        return self.generate_between(splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria)[
-            len(prompt) :
-        ]
+        return self.generate_between(
+            splited_template,
+            parameters,
+            stopping_criteria=self.summarize_stopping_criteria,
+        )[len(prompt) :]

     async def summarize_stream(
-        self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
     ):
         prompt = self.summarization_input(input_text, signature, style)
         splited_template = re.split(r"\$\{.*\}", template)
         splited_template = [prompt] + splited_template

         async for token in self.generate_between_stream(
-            splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria
+            splited_template,
+            parameters,
+            stopping_criteria=self.summarize_stopping_criteria,
         ):
             yield token

@@ -279,7 +373,9 @@ def __init__(self, token_ids: List[int]) -> None:
         self.cancelled = False
         self.token_ids = torch.tensor(token_ids, requires_grad=False)

-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ) -> bool:
         if self.cancelled:
             return True
         return torch.any(torch.eq(input_ids[0, -1], self.token_ids)).item()

@@ -304,4 +400,6 @@ def __call__(self, *args, **kwargs) -> bool:
         self.time_for_prev_token = elapsed
         self.time = current_time

-        return self.stop_until < current_time + self.time_for_prev_token * self.grow_factor
+        return (
+            self.stop_until < current_time + self.time_for_prev_token * self.grow_factor
+        )
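
One detail worth noting in the summarize hunk above: Black renders the slice as [len(prompt) :], with a space before the colon. For slice bounds that are expressions rather than plain names or literals, Black follows PEP 8 in treating the colon like a binary operator and spacing it symmetrically, omitting the space on the side where the bound is omitted.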

modules/openvino_code/server/src/utils.py

+6 -2

@@ -9,7 +9,9 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument("--port", type=int, default="8000")

     parser.add_argument("--model", type=str, required=True)
-    parser.add_argument("--tokenizer_checkpoint", type=str, required=False, default=None)
+    parser.add_argument(
+        "--tokenizer_checkpoint", type=str, required=False, default=None
+    )
     parser.add_argument("--device", type=str, required=False, default="CPU")
     parser.add_argument("--assistant", type=str, required=False, default=None)
     parser.add_argument("--summarization-endpoint", action="store_true")

@@ -34,7 +36,9 @@ def get_logger(
 class ServerLogger(logging.Logger):
     _server_log_prefix = "[OpenVINO Code Server Log]"

-    default_formatter = logging.Formatter(f"{_server_log_prefix} %(asctime)s %(levelname)s %(message)s")
+    default_formatter = logging.Formatter(
+        f"{_server_log_prefix} %(asctime)s %(levelname)s %(message)s"
+    )

     def __init__(self, name):
         super(ServerLogger, self).__init__(name)
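
To reproduce this formatting locally, running Black over the server package, e.g. black modules/openvino_code/server/src, should yield the same result (a hypothetical invocation; the exact command and configured line length are not recorded in the commit).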
