@@ -5,7 +5,17 @@
 from pathlib import Path
 from threading import Thread
 from time import time
-from typing import Any, Callable, Container, Dict, Generator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Container,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Type,
+    Union,
+)
 
 import torch
 from fastapi import Request
@@ -30,14 +40,18 @@
 model_dir = Path("models")
 model_dir.mkdir(exist_ok=True)
 
-SUMMARIZE_INSTRUCTION = "{function}\n\n# The function with {style} style docstring\n\n{signature}\n"
+SUMMARIZE_INSTRUCTION = (
+    "{function}\n\n# The function with {style} style docstring\n\n{signature}\n"
+)
 
 SUMMARIZE_STOP_TOKENS = ("\r\n", "\n")
 
 
 def get_model_class(checkpoint: Union[str, Path]) -> Type[OVModel]:
     config = AutoConfig.from_pretrained(checkpoint)
     architecture: str = config.architectures[0]
-    if architecture.endswith("ConditionalGeneration") or architecture.endswith("Seq2SeqLM"):
+    if architecture.endswith("ConditionalGeneration") or architecture.endswith(
+        "Seq2SeqLM"
+    ):
         return OVModelForSeq2SeqLM
 
     return OVModelForCausalLM
@@ -48,16 +62,27 @@ def get_model(checkpoint: str, device: str = "CPU") -> OVModel:
     model_path = model_dir / Path(checkpoint)
     if model_path.exists():
         model_class = get_model_class(model_path)
-        model = model_class.from_pretrained(model_path, ov_config=ov_config, compile=False, device=device)
+        model = model_class.from_pretrained(
+            model_path, ov_config=ov_config, compile=False, device=device
+        )
     else:
         model_class = get_model_class(checkpoint)
         try:
             model = model_class.from_pretrained(
-                checkpoint, ov_config=ov_config, compile=False, device=device, trust_remote_code=True
+                checkpoint,
+                ov_config=ov_config,
+                compile=False,
+                device=device,
+                trust_remote_code=True,
             )
         except EntryNotFoundError:
             model = model_class.from_pretrained(
-                checkpoint, ov_config=ov_config, export=True, compile=False, device=device, trust_remote_code=True
+                checkpoint,
+                ov_config=ov_config,
+                export=True,
+                compile=False,
+                device=device,
+                trust_remote_code=True,
             )
         model.save_pretrained(model_path)
     model.compile()
@@ -72,13 +97,29 @@ class GeneratorFunctor:
     def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
         raise NotImplementedError
 
-    async def generate_stream(self, input_text: str, parameters: Dict[str, Any], request: Request):
+    async def generate_stream(
+        self, input_text: str, parameters: Dict[str, Any], request: Request
+    ):
         raise NotImplementedError
 
-    def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
+    def summarize(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ):
         raise NotImplementedError
 
-    def summarize_stream(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
+    def summarize_stream(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ):
         raise NotImplementedError
 
 
@@ -113,9 +154,14 @@ def __init__(
         if summarize_stop_tokens:
             stop_tokens = []
             for token_id in self.tokenizer.vocab.values():
-                if any(stop_word in self.tokenizer.decode(token_id) for stop_word in summarize_stop_tokens):
+                if any(
+                    stop_word in self.tokenizer.decode(token_id)
+                    for stop_word in summarize_stop_tokens
+                ):
                     stop_tokens.append(token_id)
-            self.summarize_stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_tokens)])
+            self.summarize_stopping_criteria = StoppingCriteriaList(
+                [StopOnTokens(stop_tokens)]
+            )
 
     def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
         input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
@@ -126,20 +172,36 @@ def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
         stopping_criteria = StoppingCriteriaList([stop_on_time])
 
         prompt_len = input_ids.shape[-1]
-        config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
+        config = GenerationConfig.from_dict(
+            {**self.generation_config.to_dict(), **parameters}
+        )
         output_ids = self.model.generate(
-            input_ids, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+            input_ids,
+            generation_config=config,
+            stopping_criteria=stopping_criteria,
+            **self.assistant_model_config,
         )[0][prompt_len:]
-        logger.info(f"Number of input tokens: {prompt_len}; generated {len(output_ids)} tokens")
-        return self.tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        logger.info(
+            f"Number of input tokens: {prompt_len}; generated {len(output_ids)} tokens"
+        )
+        return self.tokenizer.decode(
+            output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
 
     async def generate_stream(
-        self, input_text: str, parameters: Dict[str, Any], request: Optional[Request] = None
+        self,
+        input_text: str,
+        parameters: Dict[str, Any],
+        request: Optional[Request] = None,
     ) -> Generator[str, None, None]:
         input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
-        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
         parameters["streamer"] = streamer
-        config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
+        config = GenerationConfig.from_dict(
+            {**self.generation_config.to_dict(), **parameters}
+        )
 
         stop_on_tokens = StopOnTokens([])
 
@@ -180,7 +242,9 @@ def generate_between(
         parameters: Dict[str, Any],
         stopping_criteria: Optional[StoppingCriteriaList] = None,
     ) -> str:
-        config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
+        config = GenerationConfig.from_dict(
+            {**self.generation_config.to_dict(), **parameters}
+        )
 
         prompt = torch.tensor([[]], dtype=torch.int64)
         buffer = StringIO()
@@ -192,13 +256,20 @@ def generate_between(
             prev_len = prompt.shape[-1]
 
             prompt = self.model.generate(
-                prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+                prompt,
+                generation_config=config,
+                stopping_criteria=stopping_criteria,
+                **self.assistant_model_config,
             )[
                 :, :-1
             ]  # skip the last token - stop token
 
-            decoded = self.tokenizer.decode(prompt[0, prev_len:], skip_special_tokens=True)
-            buffer.write(decoded.lstrip(" "))  # hack to delete leading spaces if there are any
+            decoded = self.tokenizer.decode(
+                prompt[0, prev_len:], skip_special_tokens=True
+            )
+            buffer.write(
+                decoded.lstrip(" ")
+            )  # hack to delete leading spaces if there are any
         buffer.write(input_parts[-1])
         return buffer.getvalue()
 
@@ -208,7 +279,9 @@ async def generate_between_stream(
         parameters: Dict[str, Any],
         stopping_criteria: Optional[StoppingCriteriaList] = None,
     ) -> Generator[str, None, None]:
-        config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
+        config = GenerationConfig.from_dict(
+            {**self.generation_config.to_dict(), **parameters}
+        )
 
         prompt = self.tokenizer.encode(input_parts[0], return_tensors="pt")
         for text_input in input_parts[1:-1]:
@@ -219,12 +292,17 @@ async def generate_between_stream(
             prev_len = prompt.shape[-1]
 
             prompt = self.model.generate(
-                prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+                prompt,
+                generation_config=config,
+                stopping_criteria=stopping_criteria,
+                **self.assistant_model_config,
             )[
                 :, :-1
             ]  # skip the last token - stop token
 
-            decoded = self.tokenizer.decode(prompt[0, prev_len:], skip_special_tokens=True)
+            decoded = self.tokenizer.decode(
+                prompt[0, prev_len:], skip_special_tokens=True
+            )
             yield decoded.lstrip(" ")  # hack to delete leading spaces if there are any
 
         yield input_parts[-1]
@@ -237,24 +315,40 @@ def summarization_input(function: str, signature: str, style: str) -> str:
             signature=signature,
         )
 
-    def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]) -> str:
+    def summarize(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ) -> str:
         prompt = self.summarization_input(input_text, signature, style)
         splited_template = re.split(r"\$\{.*\}", template)
         splited_template[0] = prompt + splited_template[0]
 
-        return self.generate_between(splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria)[
-            len(prompt) :
-        ]
+        return self.generate_between(
+            splited_template,
+            parameters,
+            stopping_criteria=self.summarize_stopping_criteria,
+        )[len(prompt) :]
 
     async def summarize_stream(
-        self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
     ):
         prompt = self.summarization_input(input_text, signature, style)
         splited_template = re.split(r"\$\{.*\}", template)
         splited_template = [prompt] + splited_template
 
         async for token in self.generate_between_stream(
-            splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria
+            splited_template,
+            parameters,
+            stopping_criteria=self.summarize_stopping_criteria,
         ):
             yield token
 
@@ -279,7 +373,9 @@ def __init__(self, token_ids: List[int]) -> None:
         self.cancelled = False
         self.token_ids = torch.tensor(token_ids, requires_grad=False)
 
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ) -> bool:
         if self.cancelled:
             return True
         return torch.any(torch.eq(input_ids[0, -1], self.token_ids)).item()
@@ -304,4 +400,6 @@ def __call__(self, *args, **kwargs) -> bool:
         self.time_for_prev_token = elapsed
         self.time = current_time
 
-        return self.stop_until < current_time + self.time_for_prev_token * self.grow_factor
+        return (
+            self.stop_until < current_time + self.time_for_prev_token * self.grow_factor
+        )
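The two patterns this diff reformats most often are a custom stopping criterion passed to generate() through a StoppingCriteriaList, and merging the base GenerationConfig with per-request parameters via GenerationConfig.from_dict({**base, **parameters}). Below is a minimal, self-contained sketch of those patterns using plain transformers classes; the checkpoint id is only an illustrative placeholder, and the server code above loads OpenVINO OVModelForCausalLM / OVModelForSeq2SeqLM models from optimum-intel instead.

# Minimal sketch, not part of the diff: custom stopping criteria + config merge.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
)


class StopOnTokens(StoppingCriteria):
    """Stop as soon as the last generated token is one of the given token ids."""

    def __init__(self, token_ids):
        self.token_ids = torch.tensor(token_ids, requires_grad=False)

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return torch.any(torch.eq(input_ids[0, -1], self.token_ids)).item()


checkpoint = "hf-internal-testing/tiny-random-gpt2"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Merge the model's default generation config with per-request parameters,
# mirroring the GenerationConfig.from_dict({**base, **parameters}) calls above.
parameters = {"max_new_tokens": 16, "do_sample": False}
config = GenerationConfig.from_dict({**model.generation_config.to_dict(), **parameters})

input_ids = tokenizer.encode("def add(a, b):", return_tensors="pt")
stopping_criteria = StoppingCriteriaList([StopOnTokens([tokenizer.eos_token_id])])
output_ids = model.generate(
    input_ids, generation_config=config, stopping_criteria=stopping_criteria
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))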