Commit b8a84b8

Demo for chunk streaming (openvinotoolkit#1320)
Add python chat example for chunk streaming
1 parent 8a74d24 commit b8a84b8

File tree

2 files changed: +24 -3 lines changed

samples/python/multinomial_causal_lm/README.md

+2
@@ -32,6 +32,8 @@ This Python example demonstrates custom detokenization with bufferization. The s

To address this, the detokenizer needs a larger context. We accumulate tokens in a tokens_cache buffer and decode multiple tokens together, adding the text to the streaming queue only when a complete decoded chunk is ready. We run a separate thread to print all new elements arriving in this queue from the generation pipeline. Each generated chunk of text is put into a synchronized queue, ensuring that all put and get operations are thread-safe and blocked until they can proceed.

+At the same time, to optimize performance in streaming mode, we provide chunk streaming. Chunk streaming noticeably improves the streamed token rate for very small LLMs: instead of running the sampling step after every generated token, it samples once after several tokens have been generated. The tokens_len parameter controls how many tokens accumulate in the tokens_cache before sampling.
+
### Troubleshooting

#### Unicode characters encoding error on Windows
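
The two paragraphs above describe the buffering and the synchronized queue only in prose. Below is a minimal, standalone sketch of the same pattern using only the Python standard library; it is not the sample's code (the real ChunkStreamer follows in multinomial_causal_lm.py), and the ToyChunkStreamer class, the fake `<tokN>` decoding, and the simulated token ids are illustrative stand-ins.

```python
import queue
import threading


class ToyChunkStreamer:
    """Buffers token ids and emits 'decoded' text once per tokens_len tokens."""

    def __init__(self, tokens_len):
        self.tokens_len = tokens_len
        self.tokens_cache = []           # accumulated, not-yet-decoded token ids
        self.text_queue = queue.Queue()  # thread-safe hand-off to the printer thread

    def _flush(self):
        if self.tokens_cache:
            # Stand-in for detokenization: join placeholder strings.
            self.text_queue.put(' '.join(f'<tok{t}>' for t in self.tokens_cache))
            self.tokens_cache = []

    def put(self, token_id):
        self.tokens_cache.append(token_id)
        if len(self.tokens_cache) % self.tokens_len == 0:
            self._flush()

    def end(self):
        self._flush()
        self.text_queue.put(None)  # sentinel: generation is finished


def printer(streamer):
    # get() blocks until a chunk (or the end sentinel) is available.
    while (chunk := streamer.text_queue.get()) is not None:
        print(chunk, flush=True)


streamer = ToyChunkStreamer(tokens_len=4)
thread = threading.Thread(target=printer, args=(streamer,), daemon=True)
thread.start()

for token_id in range(10):  # pretend these ids come from the generation loop
    streamer.put(token_id)
streamer.end()
thread.join()
```

The real sample replaces the fake decode with the tokenizer's detokenization and feeds put() from the generation pipeline, but the queue hand-off to a separate printer thread works the same way.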

samples/python/multinomial_causal_lm/multinomial_causal_lm.py

+22 -3
@@ -120,23 +120,41 @@ def end(self):
        self.put_word(None)


+class ChunkStreamer(IterableStreamer):
+
+    def __init__(self, tokenizer, tokens_len):
+        super().__init__(tokenizer)
+        self.tokens_len = tokens_len
+
+    def put(self, token_id: int) -> bool:
+        if (len(self.tokens_cache) + 1) % self.tokens_len != 0:
+            self.tokens_cache.append(token_id)
+            return False
+        return super().put(token_id)
+
+
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('model_dir')
    parser.add_argument('prompt')
    args = parser.parse_args()

    device = 'CPU' # GPU can be used as well
+    tokens_len = 10 # chunk size
    pipe = openvino_genai.LLMPipeline(args.model_dir, device)
-
-    text_print_streamer = IterableStreamer(pipe.get_tokenizer())
+
+    text_print_streamer = ChunkStreamer(
+        pipe.get_tokenizer(),
+        tokens_len
+    )
+
    def token_printer():
        # Getting next elements from iterable will be blocked until a new token is available.
        for word in text_print_streamer:
            print(word, end='', flush=True)
    printer_thread = threading.Thread(target=token_printer, daemon=True)
    printer_thread.start()
-
+
    config = openvino_genai.GenerationConfig()
    config.max_new_tokens = 100
    config.do_sample = True

@@ -148,5 +166,6 @@ def token_printer():
    pipe.generate(args.prompt, config, text_print_streamer)
    printer_thread.join()

+
if '__main__' == __name__:
    main()
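
To see how often the chunk condition in ChunkStreamer.put() actually delegates to the base class, here is a small, self-contained trace. StubBaseStreamer, DemoChunkStreamer, and decode_calls are names invented for this sketch; the stub also assumes, as a simplification, that the base class flushes its cache on every call, which the real IterableStreamer does not necessarily do.

```python
# Illustrative trace of how often ChunkStreamer.put() delegates to its base
# class. StubBaseStreamer is a stand-in: the real IterableStreamer detokenizes
# tokens_cache with the tokenizer; the stub only counts those calls.
class StubBaseStreamer:
    def __init__(self):
        self.tokens_cache = []
        self.decode_calls = 0

    def put(self, token_id) -> bool:
        self.tokens_cache.append(token_id)
        self.decode_calls += 1      # the real code would decode tokens_cache here
        self.tokens_cache = []      # simplification: assume a full flush each time
        return False                # the sample's streamers return False to keep generating


class DemoChunkStreamer(StubBaseStreamer):
    """Same put() logic as the ChunkStreamer added in this commit."""

    def __init__(self, tokens_len):
        super().__init__()
        self.tokens_len = tokens_len

    def put(self, token_id) -> bool:
        if (len(self.tokens_cache) + 1) % self.tokens_len != 0:
            self.tokens_cache.append(token_id)
            return False
        return super().put(token_id)


streamer = DemoChunkStreamer(tokens_len=10)
for token_id in range(100):           # pretend 100 tokens come from generation
    streamer.put(token_id)
print(streamer.decode_calls)          # prints 10: one decode per 10 tokens
```

With the hard-coded tokens_len = 10 from main(), detokenization work is batched roughly once per ten generated tokens instead of once per token, which is the effect the README paragraph above describes. The sample itself is still invoked as before, e.g. `python multinomial_causal_lm.py model_dir "prompt"`.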
