# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
-from transformers import AutoConfig, AutoProcessor
+from transformers import AutoConfig, AutoProcessor, AutoTokenizer
from openvino.runtime import Core
import openvino as ov
import logging as log
import json
import types
from llm_bench_utils.hook_common import get_bench_hook
-from llm_bench_utils.config_class import OV_MODEL_CLASSES_MAPPING, TOKENIZE_CLASSES_MAPPING, DEFAULT_MODEL_CLASSES, IMAGE_GEN_CLS
+from llm_bench_utils.config_class import (
+    OV_MODEL_CLASSES_MAPPING,
+    TOKENIZE_CLASSES_MAPPING,
+    DEFAULT_MODEL_CLASSES,
+    IMAGE_GEN_CLS
+)
import openvino.runtime.opset13 as opset
from transformers import pipeline
+import openvino_genai as ov_genai
+import queue
+from transformers.generation.streamers import BaseStreamer


def generate_simplified(self, *args, **kwargs):

@@ -525,3 +533,231 @@ def is_genai_available(log_msg=False):
            log.warning(ex)
        return False
    return True
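
For reviewers following the control flow: `is_genai_available()` gates which streamer implementation the benchmark uses. A hedged sketch of a call site (the selection code around it is assumed for illustration and is not part of this diff):

```python
# Hypothetical call site: pick the streamer class depending on whether
# openvino_genai can be imported. Only is_genai_available() is real here.
if is_genai_available(log_msg=True):
    streamer_cls = GenaiChunkStreamer    # openvino_genai-backed streamer (added below)
else:
    streamer_cls = OptimumChunkStreamer  # transformers/Optimum fallback (added below)
```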
+
+
+class GenaiChunkStreamer(ov_genai.StreamerBase):
+    """
+    A custom streamer class for handling token streaming and detokenization with buffering.
+
+    Attributes:
+        tokenizer (Tokenizer): The tokenizer used for encoding and decoding tokens.
+        tokens_cache (list): A buffer to accumulate tokens for detokenization.
+        text_queue (Queue): A synchronized queue for storing decoded text chunks.
+        print_len (int): The length of the printed text to manage incremental decoding.
+    """
+
+    def __init__(self, tokenizer, tokens_len=1):
+        """
+        Initializes the GenaiChunkStreamer with the given tokenizer.
+
+        Args:
+            tokenizer (Tokenizer): The tokenizer to use for encoding and decoding tokens.
+            tokens_len (int): The number of tokens to accumulate before decoding a chunk.
+        """
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.tokens_cache = []
+        self.text_queue = queue.Queue()
+        self.print_len = 0
+        self.tokens_len = tokens_len
+
+    def __iter__(self):
+        """
+        Returns the iterator object itself.
+        """
+        return self
+
+    def __next__(self):
+        """
+        Returns the next value from the text queue.
+
+        Returns:
+            str: The next decoded text chunk.
+
+        Raises:
+            StopIteration: If there are no more elements in the queue.
+        """
+        value = self.text_queue.get()  # get() blocks until a value is available.
+        if value is None:
+            raise StopIteration
+        return value
+
+    def get_stop_flag(self):
+        """
+        Checks whether the generation process should be stopped.
+
+        Returns:
+            bool: Always returns False in this implementation.
+        """
+        return False
+
+    def put_word(self, word: str):
+        """
+        Puts a word into the text queue.
+
+        Args:
+            word (str): The word to put into the queue.
+        """
+        self.text_queue.put(word)
+
+    def put(self, token_id: int) -> bool:
+        """
+        Processes a token and manages the decoding buffer. Adds decoded text to the queue.
+
+        Args:
+            token_id (int): The token_id to process.
+
+        Returns:
+            bool: True if generation should be stopped, False otherwise.
+        """
+        self.tokens_cache.append(token_id)
+        if len(self.tokens_cache) % self.tokens_len == 0:
+            text = self.tokenizer.decode(self.tokens_cache)
+
+            word = ''
+            if len(text) > self.print_len and '\n' == text[-1]:
+                # Flush the cache after the new line symbol.
+                word = text[self.print_len:]
+                self.tokens_cache = []
+                self.print_len = 0
+            elif len(text) > 0 and text[-1] == chr(65533):
+                # Don't print incomplete text that ends with the U+FFFD replacement character.
+                pass
+            elif len(text) > self.print_len:
+                # It is possible to have a shorter text after adding a new token.
+                # Print to output only if the text length has increased.
+                word = text[self.print_len:]
+                self.print_len = len(text)
+            self.put_word(word)
+
+            if self.get_stop_flag():
+                # When generation is stopped from the streamer, end() is not called, so call it here manually.
+                self.end()
+                return True  # True means stop generation.
+            else:
+                return False  # False means continue generation.
+        else:
+            return False
+
+    def end(self):
+        """
+        Flushes residual tokens from the buffer and puts a None value in the queue to signal the end.
+        """
+        text = self.tokenizer.decode(self.tokens_cache)
+        if len(text) > self.print_len:
+            word = text[self.print_len:]
+            self.put_word(word)
+        self.tokens_cache = []
+        self.print_len = 0
+        self.put_word(None)
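
A minimal usage sketch for the class above (not part of the diff): `openvino_genai.LLMPipeline` with `get_tokenizer()` and `generate()` is assumed from the openvino_genai Python API, and the model path, prompt, and `tokens_len` are placeholders. Because `generate()` blocks until completion, a common pattern is to drive it from a worker thread and consume the streamer from the main thread:

```python
import threading
import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("./ov_model_dir", "CPU")  # placeholder model path
streamer = GenaiChunkStreamer(pipe.get_tokenizer(), tokens_len=4)  # decode in 4-token chunks

# Run the blocking generate() call in a worker thread; the runtime feeds
# token ids into streamer.put(), which queues decoded text chunks.
worker = threading.Thread(
    target=lambda: pipe.generate("What is OpenVINO?", max_new_tokens=64, streamer=streamer)
)
worker.start()
for chunk in streamer:  # iteration stops when end() queues None
    print(chunk, end="", flush=True)
worker.join()
```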
+
+
+class OptimumChunkStreamer(BaseStreamer):
+    """
+    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
+    <Tip warning={true}>
+    The API for the streamer classes is still under development and may change in the future.
+    </Tip>
+    Parameters:
+        tokenizer (`AutoTokenizer`):
+            The tokenizer used to decode the tokens.
+        skip_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
+        decode_kwargs (`dict`, *optional*):
+            Additional keyword arguments to pass to the tokenizer's `decode` method.
+    Examples:
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+    >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
+    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+    >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
+    >>> streamer = OptimumChunkStreamer(tok)
+    >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
+    >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
+    An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
+    ```
+    """
+    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False,
+                 tokens_len: int = 1, **decode_kwargs):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.decode_kwargs = decode_kwargs
+        # Variables used in the streaming process.
+        self.token_cache = []
+        self.print_len = 0
+        self.next_tokens_are_prompt = True
+        self.tokens_len = tokens_len
+
+    def put(self, value):
+        """
+        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
+        """
+        if len(value.shape) > 1 and value.shape[0] > 1:
+            raise ValueError("OptimumChunkStreamer only supports batch size 1")
+        elif len(value.shape) > 1:
+            value = value[0]
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+        # Add the new tokens to the cache and decode the entire thing.
+        self.token_cache.extend(value.tolist())
+        if len(self.token_cache) % self.tokens_len == 0:
+            text = self.tokenizer.decode(
+                self.token_cache, **self.decode_kwargs
+            )
+            # After the symbol for a new line, we flush the cache.
+            if text.endswith("\n"):
+                printable_text = text[self.print_len:]
+                self.token_cache = []
+                self.print_len = 0
+            # If the last token is a CJK character, we print the characters.
+            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
+                printable_text = text[self.print_len:]
+                self.print_len += len(printable_text)
+            # Otherwise, print up to the last space char (simple heuristic to avoid printing incomplete words,
+            # which may change with the subsequent token -- there are probably smarter ways to do this!).
+            else:
+                printable_text = text[self.print_len: text.rfind(" ") + 1]
+                self.print_len += len(printable_text)
+            self.on_finalized_text(printable_text)
+
+    def end(self):
+        """Flushes any remaining cache and prints a newline to stdout."""
+        # Flush the cache, if it exists.
+        if len(self.token_cache) > 0:
+            text = self.tokenizer.decode(
+                self.token_cache, **self.decode_kwargs
+            )
+            printable_text = text[self.print_len:]
+            self.token_cache = []
+            self.print_len = 0
+        else:
+            printable_text = ""
+        self.next_tokens_are_prompt = True
+        self.on_finalized_text(printable_text, stream_end=True)
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
+        print(text, flush=True, end="" if not stream_end else None)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and are handled
+        # like all of the other languages.
+        if (
+            (cp >= 0x4E00 and cp <= 0x9FFF)
+            or (cp >= 0x3400 and cp <= 0x4DBF)
+            or (cp >= 0x20000 and cp <= 0x2A6DF)
+            or (cp >= 0x2A700 and cp <= 0x2B73F)
+            or (cp >= 0x2B740 and cp <= 0x2B81F)
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)
+        ):
+            return True
+        return False
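
A usage sketch for the Optimum path with chunked decoding (not part of the diff): `OVModelForCausalLM` from `optimum.intel` is an assumption chosen to match the benchmark's OpenVINO focus; any transformers `*ForCausalLM` model with a `generate()` method works the same way. `tokens_len=4` exercises the chunked-decoding knob this subclass adds over the stock `TextStreamer`:

```python
from optimum.intel import OVModelForCausalLM  # assumed dependency (optimum-intel)
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = OVModelForCausalLM.from_pretrained("openai-community/gpt2", export=True)
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")

# tokens_len=4 decodes the token cache only every 4th token, trimming
# per-token detokenization overhead; skip_prompt avoids echoing the input.
streamer = OptimumChunkStreamer(tok, skip_prompt=True, tokens_len=4)
_ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
```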