@@ -5,7 +5,17 @@
 from pathlib import Path
 from threading import Thread
 from time import time
-from typing import Any, Callable, Container, Dict, Generator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Container,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Type,
+    Union,
+)
 
 import torch
 from fastapi import Request
@@ -53,11 +63,20 @@ def get_model(checkpoint: str, device: str = "CPU") -> OVModel:
     model_class = get_model_class(checkpoint)
     try:
         model = model_class.from_pretrained(
-            checkpoint, ov_config=ov_config, compile=False, device=device, trust_remote_code=True
+            checkpoint,
+            ov_config=ov_config,
+            compile=False,
+            device=device,
+            trust_remote_code=True,
         )
     except EntryNotFoundError:
         model = model_class.from_pretrained(
-            checkpoint, ov_config=ov_config, export=True, compile=False, device=device, trust_remote_code=True
+            checkpoint,
+            ov_config=ov_config,
+            export=True,
+            compile=False,
+            device=device,
+            trust_remote_code=True,
         )
     model.save_pretrained(model_path)
     model.compile()
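
For context, the two `from_pretrained` calls reformatted above implement a load-or-export fallback: first try to load precompiled OpenVINO IR from the checkpoint, and only export from the PyTorch weights when the IR files are missing. A minimal runnable sketch of the same pattern, assuming optimum-intel's `OVModelForCausalLM` in place of whatever `get_model_class(checkpoint)` returns in the real code:

from huggingface_hub.utils import EntryNotFoundError
from optimum.intel import OVModelForCausalLM  # stand-in for get_model_class(checkpoint)

def load_openvino_model(checkpoint: str, device: str = "CPU") -> OVModelForCausalLM:
    try:
        # Fast path: the checkpoint already ships exported OpenVINO IR files.
        model = OVModelForCausalLM.from_pretrained(checkpoint, compile=False, device=device)
    except EntryNotFoundError:
        # Fallback: no IR found, so export from the PyTorch weights on the fly.
        model = OVModelForCausalLM.from_pretrained(checkpoint, export=True, compile=False, device=device)
    model.compile()  # compile once, after the target device is configured
    return model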
@@ -75,10 +94,24 @@ def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
     async def generate_stream(self, input_text: str, parameters: Dict[str, Any], request: Request):
         raise NotImplementedError
 
-    def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
+    def summarize(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ):
         raise NotImplementedError
 
-    def summarize_stream(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
+    def summarize_stream(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ):
         raise NotImplementedError
 
 
@@ -128,13 +161,19 @@ def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
         prompt_len = input_ids.shape[-1]
         config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
         output_ids = self.model.generate(
-            input_ids, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+            input_ids,
+            generation_config=config,
+            stopping_criteria=stopping_criteria,
+            **self.assistant_model_config,
         )[0][prompt_len:]
         logger.info(f"Number of input tokens: {prompt_len}; generated {len(output_ids)} tokens")
         return self.tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
 
     async def generate_stream(
-        self, input_text: str, parameters: Dict[str, Any], request: Optional[Request] = None
+        self,
+        input_text: str,
+        parameters: Dict[str, Any],
+        request: Optional[Request] = None,
     ) -> Generator[str, None, None]:
         input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
         streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
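
The `generate_stream` signature split here feeds the `TextIteratorStreamer` created on the hunk's last context line. The rest of the method falls outside this diff; the standard transformers pattern it builds on looks roughly like the sketch below (the checkpoint name is a hypothetical placeholder, not the repo's model):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # hypothetical checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer.encode("def hello():", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread while the caller
# iterates over the streamer, receiving decoded text as it is produced.
thread = Thread(target=model.generate, kwargs={"inputs": input_ids, "streamer": streamer, "max_new_tokens": 32})
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()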
@@ -192,7 +231,10 @@ def generate_between(
             prev_len = prompt.shape[-1]
 
             prompt = self.model.generate(
-                prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+                prompt,
+                generation_config=config,
+                stopping_criteria=stopping_criteria,
+                **self.assistant_model_config,
             )[
                 :, :-1
             ]  # skip the last token - stop token
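
Both `generate_between` hunks (this one and the async variant below) forward a `stopping_criteria` to `model.generate` and then drop the final stop token via the `[:, :-1]` slice. The criteria object itself is defined outside this diff; a hedged sketch of what such a transformers `StoppingCriteria` typically looks like:

import torch
from transformers import StoppingCriteria

class StopOnTokens(StoppingCriteria):
    """Stop generation once the last emitted token is a designated stop token."""

    def __init__(self, stop_token_ids):
        self.stop_token_ids = set(stop_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids holds the full sequence so far; check only its newest token.
        return input_ids[0, -1].item() in self.stop_token_ids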
@@ -219,7 +261,10 @@ async def generate_between_stream(
             prev_len = prompt.shape[-1]
 
             prompt = self.model.generate(
-                prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+                prompt,
+                generation_config=config,
+                stopping_criteria=stopping_criteria,
+                **self.assistant_model_config,
             )[
                 :, :-1
             ]  # skip the last token - stop token
@@ -237,24 +282,40 @@ def summarization_input(function: str, signature: str, style: str) -> str:
             signature=signature,
         )
 
-    def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]) -> str:
+    def summarize(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ) -> str:
         prompt = self.summarization_input(input_text, signature, style)
         splited_template = re.split(r"\$\{.*\}", template)
         splited_template[0] = prompt + splited_template[0]
 
-        return self.generate_between(splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria)[
-            len(prompt) :
-        ]
+        return self.generate_between(
+            splited_template,
+            parameters,
+            stopping_criteria=self.summarize_stopping_criteria,
+        )[len(prompt) :]
 
     async def summarize_stream(
-        self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
     ):
         prompt = self.summarization_input(input_text, signature, style)
         splited_template = re.split(r"\$\{.*\}", template)
         splited_template = [prompt] + splited_template
 
         async for token in self.generate_between_stream(
-            splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria
+            splited_template,
+            parameters,
+            stopping_criteria=self.summarize_stopping_criteria,
         ):
             yield token
 
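
As a quick illustration of the `re.split(r"\$\{.*\}", template)` call kept as context in this hunk (the template string below is a made-up example, not one from the repo):

import re

template = '"""${summary}"""\n'          # hypothetical docstring template
parts = re.split(r"\$\{.*\}", template)  # split on the ${...} placeholder
print(parts)  # ['"""', '"""\n'] - generate_between fills the gap between parts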