Commit d294db9

Genai/optimum support streaming output (openvinotoolkit#1290)
Support a chunk streaming mode, mainly to reduce the number of decode calls and thereby improve performance.
1 parent 3ca509f commit d294db9
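
The mechanism behind the description above, as a minimal standalone sketch (not part of the diff; the toy tokenizer and call counter are illustrative only): decoded text is emitted once every `tokens_len` tokens instead of once per token, so the number of `decode` calls drops by roughly a factor of `tokens_len`.

```python
# Illustrative sketch of chunk streaming; not code from this commit.
class ToyTokenizer:
    def __init__(self):
        self.decode_calls = 0

    def decode(self, token_ids):
        # A real tokenizer (openvino_genai.Tokenizer or transformers.AutoTokenizer) goes here.
        self.decode_calls += 1
        return "".join(chr(ord('a') + (t % 26)) for t in token_ids)


def stream(token_ids, tokenizer, tokens_len=1):
    cache = []
    printed = 0
    for tid in token_ids:
        cache.append(tid)
        if len(cache) % tokens_len == 0:  # decode once per chunk instead of once per token
            text = tokenizer.decode(cache)
            print(text[printed:], end="")
            printed = len(text)
    print()


tok = ToyTokenizer()
stream(list(range(64)), tok, tokens_len=8)
print("decode calls:", tok.decode_calls)  # 8 with tokens_len=8, versus 64 with tokens_len=1
```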

File tree

3 files changed: +328 -30 lines changed
tools/llm_bench/benchmark.py (+25 -4)
@@ -155,6 +155,8 @@ def get_argprser():
         help='Stop the generation even if output token size does not achieve infer_count or max token size ({DEFAULT_OUTPUT_TOKEN_SIZE}}).'
     )
     parser.add_argument('--set_torch_thread', default=0, type=num_infer_count_type, help='Set the number of Torch thread. ')
+    parser.add_argument('-tl', '--tokens_len', type=int, required=False, help='The number of tokens to print each time in streaming (chunk streaming) mode.')
+    parser.add_argument('--streaming', action='store_true', help='Set whether to use streaming mode; only applicable to LLMs.')
 
     return parser.parse_args()
 
@@ -170,10 +172,23 @@ def get_argprser():
 
 def main():
     logging_kwargs = {"encoding": "utf-8"} if sys.version_info[1] > 8 else {}
-    log.basicConfig(format='[ %(levelname)s ] %(message)s', level=os.environ.get("LOGLEVEL", log.INFO), stream=sys.stdout, **logging_kwargs)
+    log.basicConfig(
+        format='[ %(levelname)s ] %(message)s',
+        level=os.environ.get("LOGLEVEL", log.INFO),
+        stream=sys.stdout,
+        **logging_kwargs
+    )
     args = get_argprser()
-    model_path, framework, model_args, model_name = llm_bench_utils.model_utils.analyze_args(args)
 
+    if args.tokens_len is not None and not args.streaming:
+        log.error("--tokens_len requires --streaming to be set.")
+        exit(1)
+    if args.streaming and args.tokens_len is None:
+        log.error("--streaming requires --tokens_len to be set.")
+        exit(1)
+    model_path, framework, model_args, model_name = (
+        llm_bench_utils.model_utils.analyze_args(args)
+    )
     # Set the device for running OpenVINO backend for torch.compile()
     if model_args['torch_compile_backend']:
         ov_torch_backend_device = str(args.device)
@@ -208,8 +223,14 @@ def main():
     if args.memory_consumption:
         mem_consumption.start_collect_mem_consumption_thread()
     try:
-        iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case']](
-            model_path, framework, args.device, model_args, args.num_iters, mem_consumption)
+        if model_args['use_case'] in ['text_gen', 'code_gen']:
+            iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case']](
+                model_path, framework, args.device, args.tokens_len, args.streaming, model_args,
+                args.num_iters, mem_consumption)
+        else:
+            iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case']](
+                model_path, framework, args.device, model_args, args.num_iters,
+                mem_consumption)
         if args.report is not None or args.report_json is not None:
             model_precision = ''
             if framework == 'ov':

tools/llm_bench/llm_bench_utils/ov_utils.py (+238 -2)
@@ -2,7 +2,7 @@
 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 from pathlib import Path
-from transformers import AutoConfig, AutoProcessor
+from transformers import AutoConfig, AutoProcessor, AutoTokenizer
 from openvino.runtime import Core
 import openvino as ov
 import logging as log
@@ -11,9 +11,17 @@
 import json
 import types
 from llm_bench_utils.hook_common import get_bench_hook
-from llm_bench_utils.config_class import OV_MODEL_CLASSES_MAPPING, TOKENIZE_CLASSES_MAPPING, DEFAULT_MODEL_CLASSES, IMAGE_GEN_CLS
+from llm_bench_utils.config_class import (
+    OV_MODEL_CLASSES_MAPPING,
+    TOKENIZE_CLASSES_MAPPING,
+    DEFAULT_MODEL_CLASSES,
+    IMAGE_GEN_CLS
+)
 import openvino.runtime.opset13 as opset
 from transformers import pipeline
+import openvino_genai as ov_genai
+import queue
+from transformers.generation.streamers import BaseStreamer
 
 
 def generate_simplified(self, *args, **kwargs):
@@ -525,3 +533,231 @@ def is_genai_available(log_msg=False):
             log.warning(ex)
         return False
     return True
+
+
+class GenaiChunkStreamer(ov_genai.StreamerBase):
+    """
+    A custom streamer class for handling token streaming and detokenization with buffering.
+
+    Attributes:
+        tokenizer (Tokenizer): The tokenizer used for encoding and decoding tokens.
+        tokens_cache (list): A buffer to accumulate tokens for detokenization.
+        text_queue (Queue): A synchronized queue for storing decoded text chunks.
+        print_len (int): The length of the printed text, used to manage incremental decoding.
+        tokens_len (int): The number of tokens to accumulate before each decode call.
+    """
+
+    def __init__(self, tokenizer, tokens_len=1):
+        """
+        Initializes the GenaiChunkStreamer with the given tokenizer.
+
+        Args:
+            tokenizer (Tokenizer): The tokenizer to use for encoding and decoding tokens.
+            tokens_len (int): The number of tokens to accumulate before decoding.
+        """
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.tokens_cache = []
+        self.text_queue = queue.Queue()
+        self.print_len = 0
+        self.tokens_len = tokens_len
+
+    def __iter__(self):
+        """
+        Returns the iterator object itself.
+        """
+        return self
+
+    def __next__(self):
+        """
+        Returns the next value from the text queue.
+
+        Returns:
+            str: The next decoded text chunk.
+
+        Raises:
+            StopIteration: If there are no more elements in the queue.
+        """
+        value = self.text_queue.get()  # get() blocks until a value is available.
+        if value is None:
+            raise StopIteration
+        return value
+
+    def get_stop_flag(self):
+        """
+        Checks whether the generation process should be stopped.
+
+        Returns:
+            bool: Always returns False in this implementation.
+        """
+        return False
+
+    def put_word(self, word: str):
+        """
+        Puts a word into the text queue.
+
+        Args:
+            word (str): The word to put into the queue.
+        """
+        self.text_queue.put(word)
+
+    def put(self, token_id: int) -> bool:
+        """
+        Processes a token and manages the decoding buffer. Adds decoded text to the queue.
+
+        Args:
+            token_id (int): The token_id to process.
+
+        Returns:
+            bool: True if generation should be stopped, False otherwise.
+        """
+        self.tokens_cache.append(token_id)
+        if len(self.tokens_cache) % self.tokens_len == 0:
+            text = self.tokenizer.decode(self.tokens_cache)
+
+            word = ''
+            if len(text) > self.print_len and '\n' == text[-1]:
+                # Flush the cache after the new line symbol.
+                word = text[self.print_len:]
+                self.tokens_cache = []
+                self.print_len = 0
+            elif len(text) > 0 and text[-1] == chr(65533):
+                # Don't print incomplete text (the last token decoded to the Unicode replacement character).
+                pass
+            elif len(text) > self.print_len:
+                # It is possible to have a shorter text after adding a new token.
+                # Print to output only if the text length has increased.
+                word = text[self.print_len:]
+                self.print_len = len(text)
+            self.put_word(word)
+
+            if self.get_stop_flag():
+                # When generation is stopped from the streamer, end() is not called, so call it here manually.
+                self.end()
+                return True  # True means stop generation.
+            else:
+                return False  # False means continue generation.
+        else:
+            return False
+
+    def end(self):
+        """
+        Flushes residual tokens from the buffer and puts a None value in the queue to signal the end.
+        """
+        text = self.tokenizer.decode(self.tokens_cache)
+        if len(text) > self.print_len:
+            word = text[self.print_len:]
+            self.put_word(word)
+            self.tokens_cache = []
+            self.print_len = 0
+        self.put_word(None)
+
+
+class OptimumChunkStreamer(BaseStreamer):
+    """
+    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
+
+    <Tip warning={true}>
+    The API for the streamer classes is still under development and may change in the future.
+    </Tip>
+
+    Parameters:
+        tokenizer (`AutoTokenizer`):
+            The tokenizer used to decode the tokens.
+        skip_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
+        tokens_len (`int`, *optional*, defaults to 1):
+            The number of tokens to accumulate before each decode call.
+        decode_kwargs (`dict`, *optional*):
+            Additional keyword arguments to pass to the tokenizer's `decode` method.
+
+    Examples:
+        ```python
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
+        >>> streamer = TextStreamer(tok)
+        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
+        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
+        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
+        ```
+    """
+
+    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False,
+                 tokens_len: int = 1, **decode_kwargs):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.decode_kwargs = decode_kwargs
+        # Variables used in the streaming process.
+        self.token_cache = []
+        self.print_len = 0
+        self.next_tokens_are_prompt = True
+        self.tokens_len = tokens_len
+
+    def put(self, value):
+        """
+        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
+        """
+        if len(value.shape) > 1 and value.shape[0] > 1:
+            raise ValueError("TextStreamer only supports batch size 1")
+        elif len(value.shape) > 1:
+            value = value[0]
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+        # Add the new token to the cache and decode the entire thing.
+        self.token_cache.extend(value.tolist())
+        if len(self.token_cache) % self.tokens_len == 0:
+            text = self.tokenizer.decode(
+                self.token_cache, **self.decode_kwargs
+            )
+            # After the symbol for a new line, we flush the cache.
+            if text.endswith("\n"):
+                printable_text = text[self.print_len:]
+                self.token_cache = []
+                self.print_len = 0
+            # If the last token is a CJK character, we print the characters.
+            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
+                printable_text = text[self.print_len:]
+                self.print_len += len(printable_text)
+            # Otherwise, print up to the last space character (a simple heuristic to avoid printing
+            # incomplete words, which may change with the subsequent token -- there are probably
+            # smarter ways to do this!).
+            else:
+                printable_text = text[self.print_len: text.rfind(" ") + 1]
+                self.print_len += len(printable_text)
+            self.on_finalized_text(printable_text)
+
+    def end(self):
+        """Flushes any remaining cache and prints a newline to stdout."""
+        # Flush the cache, if it exists.
+        if len(self.token_cache) > 0:
+            text = self.tokenizer.decode(
+                self.token_cache, **self.decode_kwargs
+            )
+            printable_text = text[self.print_len:]
+            self.token_cache = []
+            self.print_len = 0
+        else:
+            printable_text = ""
+        self.next_tokens_are_prompt = True
+        self.on_finalized_text(printable_text, stream_end=True)
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
+        print(text, flush=True, end="" if not stream_end else None)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as are Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and are handled
+        # like all of the other languages.
+        if (
+            (cp >= 0x4E00 and cp <= 0x9FFF)
+            or (cp >= 0x3400 and cp <= 0x4DBF)
+            or (cp >= 0x20000 and cp <= 0x2A6DF)
+            or (cp >= 0x2A700 and cp <= 0x2B73F)
+            or (cp >= 0x2B740 and cp <= 0x2B81F)
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)
+        ):
+            return True
+        return False
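
A possible usage sketch for the `GenaiChunkStreamer` added above, not part of the commit: generation runs on a worker thread while the main thread iterates over the streamer, which yields text chunks roughly every `tokens_len` tokens. The model directory is a placeholder, and the threading pattern follows the iterable-streamer samples shipped with openvino.genai; treat it as an assumption rather than the benchmark's actual call path.

```python
from threading import Thread

import openvino_genai as ov_genai

from llm_bench_utils.ov_utils import GenaiChunkStreamer  # class added in this commit

# Placeholder path: any directory containing an LLM exported for OpenVINO GenAI.
model_dir = "path/to/ov_model_dir"
pipe = ov_genai.LLMPipeline(model_dir, "CPU")

# Decode once per 4 tokens instead of once per generated token.
streamer = GenaiChunkStreamer(pipe.get_tokenizer(), tokens_len=4)

worker = Thread(
    target=pipe.generate,
    args=("What is OpenVINO?",),
    kwargs={"max_new_tokens": 64, "streamer": streamer},
)
worker.start()
for chunk in streamer:  # __next__() blocks until a chunk is queued; a None sentinel ends iteration
    print(chunk, end="", flush=True)
worker.join()
```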

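Likewise, a hedged sketch of how `OptimumChunkStreamer` could be passed to an Optimum-Intel model's `generate()` call (again not from the commit; the gpt2 checkpoint mirrors the docstring example, and `export=True` assumes optimum-intel can convert the model on the fly):

```python
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

from llm_bench_utils.ov_utils import OptimumChunkStreamer  # class added in this commit

tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = OVModelForCausalLM.from_pretrained("openai-community/gpt2", export=True)

inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
# Flush decoded text every 4 tokens rather than after every token.
streamer = OptimumChunkStreamer(tok, skip_prompt=True, tokens_len=4)
_ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
```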