Skip to content

Commit 415d910

Browse files
Make TextStreamer public & add unit-tests (#1700)
- Check whether text is incomplete should be done as early as possible in order to write to `m_decoded_lengths` relevant information that text in the position of this token is incomplete (is expressed with `-1`). - Made `TextStreamer` public. - Also added unit-tests for `TextStreamer` ~which use pybind to make available this object from Python for test purposes. But `TextCallbackStreamer` is still private and visible only for developers.~ Ticket: CVS-148635 CVS-160780 --------- Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
1 parent be38f4d commit 415d910

23 files changed

+318
-131
lines changed

.github/labeler.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
'category: LLM':
1919
- 'src/cpp/include/openvino/genai/llm_pipeline.hpp'
20+
- 'src/cpp/include/openvino/genai/text_streamer.hpp'
2021
- 'src/cpp/src/llm_pipeline.cpp'
2122
- 'src/cpp/src/lm_encoding.hpp'
2223
- 'src/cpp/src/lm_encoding.cpp'
2324
- 'src/cpp/src/llm_pipeline_base.hpp'
2425
- 'src/cpp/src/llm_pipeline_static.hpp'
2526
- 'src/cpp/src/llm_pipeline_static.cpp'
26-
- 'src/cpp/src/text_callback_streamer.cpp'
27-
- 'src/cpp/src/text_callback_streamer.hpp'
27+
- 'src/cpp/src/text_streamer.cpp'
2828
- 'src/python/py_llm_pipeline.cpp'
2929
- 'tests/python_tests/test_llm_pipeline.py'
3030

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (C) 2023-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
#pragma once
5+
6+
#include "openvino/genai/streamer_base.hpp"
7+
#include "openvino/genai/tokenizer.hpp"
8+
9+
namespace ov {
10+
namespace genai {
11+
12+
using CallbackTypeVariant = std::variant<bool, StreamingStatus>;
13+
14+
/**
15+
* @brief TextStreamer is used to decode tokens into text and call a user-defined callback function.
16+
*
17+
* @param tokenizer Tokenizer object to decode tokens into text.
18+
* @param callback User-defined callback function to process the decoded text, callback should return
19+
* either boolean flag or StreamingStatus.
20+
*/
21+
class OPENVINO_GENAI_EXPORTS TextStreamer: public StreamerBase {
22+
StreamingStatus set_streaming_status(CallbackTypeVariant callback_status);
23+
24+
std::function<CallbackTypeVariant(std::string)> m_subword_callback = [](std::string words)->bool { return false; };
25+
StreamingStatus run_callback_if_needed(const std::string& text);
26+
27+
public:
28+
StreamingStatus write(int64_t token) override;
29+
30+
void end() override;
31+
32+
TextStreamer(const Tokenizer& tokenizer, std::function<CallbackTypeVariant(std::string)> callback);
33+
34+
protected:
35+
Tokenizer m_tokenizer;
36+
std::vector<int64_t> m_tokens_cache;
37+
std::vector<int64_t> m_decoded_lengths;
38+
size_t m_printed_len = 0;
39+
};
40+
41+
} // namespace genai
42+
} // namespace ov

src/cpp/src/continuous_batching_impl.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <atomic>
55
#include <thread>
66

7-
#include "text_callback_streamer.hpp"
7+
#include "openvino/genai/text_streamer.hpp"
88
#include "continuous_batching_impl.hpp"
99
#include "utils.hpp"
1010
#include "paged_attention_transformations.hpp"

src/cpp/src/llm_pipeline_stateful.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
#include "lora_helper.hpp"
88
#include "lm_encoding.hpp"
9-
#include "text_callback_streamer.hpp"
9+
#include "openvino/genai/text_streamer.hpp"
10+
1011
#include "utils.hpp"
1112

1213
namespace ov::genai {

src/cpp/src/llm_pipeline_static.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222
#include "openvino/runtime/properties.hpp"
2323
#include "openvino/runtime/intel_npu/properties.hpp"
2424
#include "openvino/core/parallel.hpp"
25+
#include "openvino/genai/text_streamer.hpp"
2526

2627
#include <jinja2cpp/user_callable.h>
2728

28-
#include "text_callback_streamer.hpp"
29+
2930
#include "json_utils.hpp"
3031
#include "utils.hpp"
3132

src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
#include "utils.hpp"
77
#include "prompt_lookup_impl.hpp"
8-
#include "text_callback_streamer.hpp"
8+
#include "openvino/genai/text_streamer.hpp"
9+
910

1011
namespace ov::genai {
1112
template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};

src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include <thread>
55

6-
#include "text_callback_streamer.hpp"
6+
#include "openvino/genai/text_streamer.hpp"
77
#include "speculative_decoding_impl.hpp"
88
#include "paged_attention_transformations.hpp"
99
#include "utils.hpp"

src/cpp/src/text_callback_streamer.hpp

-37
This file was deleted.

src/cpp/src/text_callback_streamer.cpp src/cpp/src/text_streamer.cpp

+25-16
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
// Copyright (C) 2023-2025 Intel Corporation
22
// SPDX-License-Identifier: Apache-2.0
33

4-
#include "text_callback_streamer.hpp"
4+
#include "openvino/genai/text_streamer.hpp"
5+
56

67
namespace ov {
78
namespace genai {
89

9-
TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function<ov::genai::CallbackTypeVariant(std::string)> callback) {
10+
TextStreamer::TextStreamer(const Tokenizer& tokenizer, std::function<ov::genai::CallbackTypeVariant(std::string)> callback) {
1011
m_tokenizer = tokenizer;
11-
m_on_finalized_subword_callback = callback;
12+
m_subword_callback = callback;
1213
}
1314

14-
StreamingStatus TextCallbackStreamer::write(int64_t token) {
15+
StreamingStatus TextStreamer::write(int64_t token) {
1516
std::stringstream res;
1617
m_tokens_cache.push_back(token);
1718
std::string text = m_tokenizer.decode(m_tokens_cache);
@@ -23,21 +24,21 @@ StreamingStatus TextCallbackStreamer::write(int64_t token) {
2324
m_tokens_cache.clear();
2425
m_decoded_lengths.clear();
2526
m_printed_len = 0;
26-
return set_streaming_status(m_on_finalized_subword_callback(res.str()));
27+
return run_callback_if_needed(res.str());
2728
}
2829

30+
constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
31+
if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
32+
m_decoded_lengths[m_decoded_lengths.size() - 1] = -1;
33+
// Don't print incomplete text
34+
return run_callback_if_needed(res.str());
35+
}
2936
constexpr size_t delay_n_tokens = 3;
3037
// In some cases adding the next token can shorten the text,
3138
// e.g. when apostrophe removing regex had worked after adding new tokens.
3239
// Printing several last tokens is delayed.
3340
if (m_decoded_lengths.size() < delay_n_tokens) {
34-
return set_streaming_status(m_on_finalized_subword_callback(res.str()));
35-
}
36-
constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
37-
if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
38-
m_decoded_lengths[m_decoded_lengths.size() - 1] = -1;
39-
// Don't print incomplete text
40-
return set_streaming_status(m_on_finalized_subword_callback(res.str()));
41+
return run_callback_if_needed(res.str());
4142
}
4243
auto print_until = m_decoded_lengths[m_decoded_lengths.size() - delay_n_tokens];
4344
if (print_until != -1 && print_until > m_printed_len) {
@@ -47,17 +48,25 @@ StreamingStatus TextCallbackStreamer::write(int64_t token) {
4748
m_printed_len = print_until;
4849
}
4950

50-
return set_streaming_status(m_on_finalized_subword_callback(res.str()));
51+
return run_callback_if_needed(res.str());
5152
}
5253

53-
StreamingStatus TextCallbackStreamer::set_streaming_status(CallbackTypeVariant callback_status) {
54+
StreamingStatus TextStreamer::set_streaming_status(CallbackTypeVariant callback_status) {
5455
if (auto res = std::get_if<StreamingStatus>(&callback_status))
5556
return *res;
5657
else
5758
return std::get<bool>(callback_status) ? StreamingStatus::STOP : StreamingStatus::RUNNING;
5859
}
5960

60-
void TextCallbackStreamer::end() {
61+
StreamingStatus TextStreamer::run_callback_if_needed(const std::string& text) {
62+
if (text.empty()) {
63+
return StreamingStatus::RUNNING;
64+
} else {
65+
return set_streaming_status(m_subword_callback(text));
66+
}
67+
}
68+
69+
void TextStreamer::end() {
6170
std::stringstream res;
6271
std::string text = m_tokenizer.decode(m_tokens_cache);
6372
if (text.size() <= m_printed_len)
@@ -66,7 +75,7 @@ void TextCallbackStreamer::end() {
6675
m_tokens_cache.clear();
6776
m_decoded_lengths.clear();
6877
m_printed_len = 0;
69-
m_on_finalized_subword_callback(res.str());
78+
m_subword_callback(res.str());
7079
return;
7180
}
7281

src/cpp/src/threaded_streamer.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
#include <thread>
77

88
#include "openvino/genai/llm_pipeline.hpp"
9+
#include "openvino/genai/text_streamer.hpp"
910
#include "openvino/genai/tokenizer.hpp"
1011
#include "synchronized_queue.hpp"
11-
#include "text_callback_streamer.hpp"
1212
#include "utils.hpp"
1313

1414
namespace ov {

src/cpp/src/utils.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
#include "openvino/op/slice.hpp"
1616
#include "openvino/op/tanh.hpp"
1717
#include "openvino/op/transpose.hpp"
18+
#include "openvino/genai/text_streamer.hpp"
1819

19-
#include "text_callback_streamer.hpp"
2020

2121
#include "sampler.hpp"
2222

@@ -160,10 +160,10 @@ std::shared_ptr<StreamerBase> create_streamer(StreamerVariant streamer, Tokenize
160160
return streamer;
161161
},
162162
[&tokenizer = tokenizer](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
163-
return std::make_unique<TextCallbackStreamer>(tokenizer, streamer);
163+
return std::make_unique<TextStreamer>(tokenizer, streamer);
164164
},
165165
[&tokenizer = tokenizer](const std::function<ov::genai::StreamingStatus(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
166-
return std::make_unique<TextCallbackStreamer>(tokenizer, streamer);
166+
return std::make_unique<TextStreamer>(tokenizer, streamer);
167167
}
168168
}, streamer);
169169

src/cpp/src/visual_language/pipeline.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
#include "openvino/genai/visual_language/pipeline.hpp"
88
#include "openvino/genai/visual_language/perf_metrics.hpp"
99
#include "openvino/genai/tokenizer.hpp"
10+
#include "openvino/genai/text_streamer.hpp"
1011

1112
#include "visual_language/vlm_config.hpp"
1213
#include "visual_language/inputs_embedder.hpp"
1314
#include "visual_language/embedding_model.hpp"
1415

1516
#include "sampler.hpp"
16-
#include "text_callback_streamer.hpp"
1717
#include "utils.hpp"
1818
#include "lm_encoding.hpp"
1919

src/cpp/src/whisper/streamer.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
#include "streamer.hpp"
55

6-
#include "text_callback_streamer.hpp"
6+
#include "openvino/genai/text_streamer.hpp"
77

88
namespace ov {
99
namespace genai {
1010

1111
StreamingStatus ChunkTextCallbackStreamer::write(int64_t token) {
12-
return ov::genai::TextCallbackStreamer::write(token);
12+
return ov::genai::TextStreamer::write(token);
1313
}
1414

1515
StreamingStatus ChunkTextCallbackStreamer::write_chunk(std::vector<int64_t> tokens) {
@@ -21,11 +21,11 @@ StreamingStatus ChunkTextCallbackStreamer::write_chunk(std::vector<int64_t> toke
2121
m_tokens_cache.insert(m_tokens_cache.end(), tokens.begin(), tokens.end() - 1);
2222
}
2323

24-
return ov::genai::TextCallbackStreamer::write(tokens.back());
24+
return ov::genai::TextStreamer::write(tokens.back());
2525
}
2626

2727
void ChunkTextCallbackStreamer::end() {
28-
ov::genai::TextCallbackStreamer::end();
28+
ov::genai::TextStreamer::end();
2929
}
3030

3131
} // namespace genai

src/cpp/src/whisper/streamer.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55

66
#include "openvino/genai/tokenizer.hpp"
77
#include "openvino/genai/whisper_pipeline.hpp"
8-
#include "text_callback_streamer.hpp"
8+
#include "openvino/genai/text_streamer.hpp"
99

1010
namespace ov {
1111
namespace genai {
1212

13-
class ChunkTextCallbackStreamer : private TextCallbackStreamer, public ChunkStreamerBase {
13+
class ChunkTextCallbackStreamer : private TextStreamer, public ChunkStreamerBase {
1414
public:
1515
StreamingStatus write(int64_t token) override;
1616
StreamingStatus write_chunk(std::vector<int64_t> tokens) override;
1717
void end() override;
1818

1919
ChunkTextCallbackStreamer(const Tokenizer& tokenizer, std::function<ov::genai::CallbackTypeVariant(std::string)> callback)
20-
: TextCallbackStreamer(tokenizer, callback){};
20+
: TextStreamer(tokenizer, callback){};
2121
};
2222

2323
} // namespace genai

src/python/openvino_genai/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
PerfMetrics,
1717
StreamerBase,
1818
get_version,
19-
StreamingStatus
19+
StreamingStatus,
20+
TextStreamer
2021
)
2122

2223
__version__ = get_version()

src/python/openvino_genai/__init__.pyi

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ from openvino_genai.py_openvino_genai import StreamerBase
3333
from openvino_genai.py_openvino_genai import StreamingStatus
3434
from openvino_genai.py_openvino_genai import T5EncoderModel
3535
from openvino_genai.py_openvino_genai import Text2ImagePipeline
36+
from openvino_genai.py_openvino_genai import TextStreamer
3637
from openvino_genai.py_openvino_genai import TokenizedInputs
3738
from openvino_genai.py_openvino_genai import Tokenizer
3839
from openvino_genai.py_openvino_genai import TorchGenerator
@@ -46,5 +47,5 @@ from openvino_genai.py_openvino_genai import draft_model
4647
from openvino_genai.py_openvino_genai import get_version
4748
import os as os
4849
from . import py_openvino_genai
49-
__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationResult', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai']
50+
__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationResult', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'T5EncoderModel', 'Text2ImagePipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai']
5051
__version__: str

src/python/openvino_genai/py_openvino_genai.pyi

+16-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from __future__ import annotations
55
import openvino._pyopenvino
66
import os
77
import typing
8-
__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version']
8+
__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'T5EncoderModel', 'Text2ImagePipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version']
99
class Adapter:
1010
"""
1111
Immutable LoRA Adapter that carries the adaptation matrices and serves as unique adapter identifier.
@@ -1706,6 +1706,21 @@ class Text2ImagePipeline:
17061706
...
17071707
def set_scheduler(self, scheduler: Scheduler) -> None:
17081708
...
1709+
class TextStreamer(StreamerBase):
1710+
"""
1711+
1712+
TextStreamer is used to decode tokens into text and call a user-defined callback function.
1713+
1714+
tokenizer: Tokenizer object to decode tokens into text.
1715+
callback: User-defined callback function to process the decoded text, callback should return either boolean flag or StreamingStatus.
1716+
1717+
"""
1718+
def __init__(self, tokenizer: Tokenizer, callback: typing.Callable[[str], bool | StreamingStatus]) -> None:
1719+
...
1720+
def end(self) -> None:
1721+
...
1722+
def write(self, token: int) -> StreamingStatus:
1723+
...
17091724
class TokenizedInputs:
17101725
attention_mask: openvino._pyopenvino.Tensor
17111726
input_ids: openvino._pyopenvino.Tensor

0 commit comments

Comments
 (0)