
Commit f93b92b

Add a choice of how to end streaming from callback: STOP or CANCEL (openvinotoolkit#1476)
1 parent: e1339bf


54 files changed: +681 -305 lines

samples/cpp/text_generation/chat_sample.cpp (+3 -3)

@@ -15,11 +15,11 @@ int main(int argc, char* argv[]) try {
 
     ov::genai::GenerationConfig config;
     config.max_new_tokens = 100;
-    std::function<bool(std::string)> streamer = [](std::string word) {
+
+    auto streamer = [](std::string word) {
         std::cout << word << std::flush;
         // Return flag corresponds whether generation should be stopped.
-        // false means continue generation.
-        return false;
+        return ov::genai::StreamingStatus::RUNNING;
     };
 
     pipe.start_chat();
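As a point of reference, here is a minimal sketch (not part of this commit) of how a callback can use the new return values: STOP ends generation while keeping the chat history, whereas CANCEL would also drop the last prompt and answer. The model directory and the 200-character cut-off are placeholder choices.

#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>

int main() {
    ov::genai::LLMPipeline pipe("model_dir", "CPU");  // placeholder model directory

    size_t printed_chars = 0;
    auto streamer = [&printed_chars](std::string word) {
        std::cout << word << std::flush;
        printed_chars += word.size();
        // Placeholder condition: stop streaming after roughly 200 characters; history is kept.
        return printed_chars < 200 ? ov::genai::StreamingStatus::RUNNING
                                   : ov::genai::StreamingStatus::STOP;
    };

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;
    pipe.generate("The Sun is yellow because", config, streamer);
}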

samples/cpp/text_generation/multinomial_causal_lm.cpp (+1 -1)

@@ -21,7 +21,7 @@ int main(int argc, char* argv[]) try {
     config.top_k = 30;
     auto streamer = [](std::string subword) {
         std::cout << subword << std::flush;
-        return false;
+        return ov::genai::StreamingStatus::RUNNING;
     };
 
     // Since the streamer is set, the results will

samples/cpp/text_generation/prompt_lookup_decoding_lm.cpp (+1 -1)

@@ -29,7 +29,7 @@ int main(int argc, char* argv[]) try {
 
     auto streamer = [](std::string subword) {
         std::cout << subword << std::flush;
-        return false;
+        return ov::genai::StreamingStatus::RUNNING;
     };
 
     // Since the streamer is set, the results will

samples/cpp/text_generation/speculative_decoding_lm.cpp (+1 -1)

@@ -33,7 +33,7 @@ int main(int argc, char* argv[]) try {
 
     auto streamer = [](std::string subword) {
         std::cout << subword << std::flush;
-        return false;
+        return ov::genai::StreamingStatus::RUNNING;
     };
 
     // Since the streamer is set, the results will

samples/python/text_generation/chat_sample.py (+1 -3)

@@ -9,9 +9,7 @@
 def streamer(subword):
     print(subword, end='', flush=True)
     # Return flag corresponds whether generation should be stopped.
-    # False means continue generation.
-    return False
-
+    return openvino_genai.StreamingStatus.RUNNING
 
 def main():
     parser = argparse.ArgumentParser()

samples/python/text_generation/multinomial_causal_lm.py (+12 -11)

@@ -56,12 +56,12 @@ def __next__(self):
 
     def get_stop_flag(self):
         """
-        Checks whether the generation process should be stopped.
+        Checks whether the generation process should be stopped or cancelled.
 
         Returns:
-            bool: Always returns False in this implementation.
+            openvino_genai.StreamingStatus: Always returns RUNNING in this implementation.
         """
-        return False
+        return openvino_genai.StreamingStatus.RUNNING
 
     def put_word(self, word: str):
         """

@@ -72,7 +72,7 @@ def put_word(self, word: str):
         """
         self.text_queue.put(word)
 
-    def put(self, token_id: int) -> bool:
+    def write(self, token_id: int) -> openvino_genai.StreamingStatus:
         """
         Processes a token and manages the decoding buffer. Adds decoded text to the queue.
 

@@ -106,12 +106,12 @@ def put(self, token_id: int) -> bool:
             self.print_len = print_until
             self.put_word(word)
 
-        if self.get_stop_flag():
+        stop_flag = self.get_stop_flag()
+        if stop_flag != openvino_genai.StreamingStatus.RUNNING:
             # When generation is stopped from streamer then end is not called, need to call it here manually.
             self.end()
-            return True # True means stop generation
-        else:
-            return False # False means continue generation
+
+        return stop_flag
 
     def end(self):
         """

@@ -123,6 +123,7 @@ def end(self):
             self.put_word(word)
         self.tokens_cache = []
         self.print_len = 0
+        self.put_word('\n')
         self.put_word(None)
 

@@ -132,12 +133,12 @@ def __init__(self, tokenizer, tokens_len):
         super().__init__(tokenizer)
         self.tokens_len = tokens_len
 
-    def put(self, token_id: int) -> bool:
+    def write(self, token_id: int) -> openvino_genai.StreamingStatus:
         if (len(self.tokens_cache) + 1) % self.tokens_len != 0:
             self.tokens_cache.append(token_id)
             self.decoded_lengths.append(-1)
-            return False
-        return super().put(token_id)
+            return openvino_genai.StreamingStatus.RUNNING
+        return super().write(token_id)
 
 
 def main():

samples/python/text_generation/prompt_lookup_decoding_lm.py (+4 -5)

@@ -5,11 +5,10 @@
 import argparse
 import openvino_genai
 
-def streamer(subword):
-    print(subword, end='', flush=True)
-    # Return flag corresponds whether generation should be stopped.
-    # False means continue generation.
-    return False
+def streamer(subword):
+    print(subword, end='', flush=True)
+    # Return flag corresponds whether generation should be stopped.
+    return openvino_genai.StreamingStatus.RUNNING
 
 def main():
     parser = argparse.ArgumentParser()

samples/python/text_generation/speculative_decoding_lm.py (+1 -2)

@@ -9,8 +9,7 @@
 def streamer(subword):
     print(subword, end='', flush=True)
     # Return flag corresponds whether generation should be stopped.
-    # False means continue generation.
-    return False
+    return openvino_genai.StreamingStatus.RUNNING
 
 def main():
     parser = argparse.ArgumentParser()

samples/python/visual_language_chat/visual_language_chat.py (+1 -1)

@@ -23,7 +23,7 @@ def streamer(subword: str) -> bool:
     print(subword, end='', flush=True)
 
     # No value is returned as in this example we don't want to stop the generation in this method.
-    # "return None" will be treated the same as "return False".
+    # "return None" will be treated the same as "return openvino_genai.StreamingStatus.RUNNING".
 
 
 def read_image(path: str) -> Tensor:

src/README.md (+1 -2)

@@ -172,8 +172,7 @@ int main(int argc, char* argv[]) {
     auto streamer = [](std::string word) {
         std::cout << word << std::flush;
         // Return flag corresponds whether generation should be stopped.
-        // false means continue generation.
-        return false;
+        return ov::genai::StreamingStatus::RUNNING;
     };
     std::cout << pipe.generate("The Sun is yellow because", ov::genai::streamer(streamer), ov::genai::max_new_tokens(200));
 }
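For contrast with the sample above, the following sketch (not from the commit) shows CANCEL in a chat setting, where the cancelled prompt and the partial answer are removed from history. The model directory and the banned-word condition are illustrative only.

#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>

int main() {
    ov::genai::LLMPipeline pipe("model_dir", "CPU");  // placeholder model directory

    std::string collected;
    auto streamer = [&collected](std::string word) {
        collected += word;
        std::cout << word << std::flush;
        // Placeholder condition: abandon the turn if a banned word shows up.
        return collected.find("forbidden") == std::string::npos
                   ? ov::genai::StreamingStatus::RUNNING
                   : ov::genai::StreamingStatus::CANCEL;
    };

    pipe.start_chat();
    pipe.generate("Tell me a story", ov::genai::streamer(streamer), ov::genai::max_new_tokens(200));
    pipe.finish_chat();
}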

src/cpp/include/openvino/genai/generation_handle.hpp (+19 -5)

@@ -11,14 +11,17 @@
 #include "openvino/genai/perf_metrics.hpp"
 
 namespace ov::genai {
+
 enum class GenerationStatus {
     RUNNING = 0, // Default status for ongoing generation
     FINISHED = 1, // Status set when generation has been finished
     IGNORED = 2, // Status set when generation run into out-of-memory condition and could not be continued
-    DROPPED_BY_PIPELINE = 3, // Currently not used, TODO: implement abort functionality
-    DROPPED_BY_HANDLE = 4 // Status set when generation handle is dropped
+    CANCEL = 3, // Status set when generation handle is cancelled. The last prompt and all generated tokens will be dropped from history, KV cache will include history but last step.
+    STOP = 4, // Status set when generation handle is stopped. History will be kept, KV cache will include the last prompt and generated tokens.
+    DROPPED_BY_HANDLE OPENVINO_ENUM_DEPRECATED("Please, use `STOP` instead of `DROPPED_BY_HANDLE`.") = GenerationStatus::STOP // Status set when generation handle is dropped.
 };
 
+
 struct EncodedGenerationResult {
     // request ID - obsolete when handle API is approved as handle will connect results with prompts.
     uint64_t m_request_id;

@@ -70,10 +73,10 @@ using GenerationOutputs = std::unordered_map<uint64_t, GenerationOutput>;
 
 class GenerationStream;
 
-class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
+class OPENVINO_GENAI_EXPORTS
+GenerationHandleImpl {
     std::shared_ptr<GenerationStream> m_generation_stream;
-    ov::genai::GenerationConfig m_sampling_params;
-
+    ov::genai::GenerationConfig m_sampling_params;
 public:
     GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
         m_generation_stream(std::move(generation_stream)),

@@ -88,10 +91,21 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
     GenerationStatus get_status();
 
     bool can_read();
+
+    OPENVINO_DEPRECATED("Please, use `stop()` instead of `drop()`. Support will be removed in 2026.0.0 release.")
     bool is_dropped();
 
+    bool is_stopped();
+
+    bool is_cancelled();
+
+    OPENVINO_DEPRECATED("Please, use `stop()` instead of `drop()`. Support will be removed in 2026.0.0 release.")
     void drop();
 
+    void stop();
+
+    void cancel();
+
     // Reads result of a generation for single iteration
     GenerationOutputs read();
     // Reads all generated tokens for all sequences
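A hedged sketch of how the new handle methods might be exercised with the continuous batching API; this usage is not part of the commit, and the request id, prompt, and 64-token budget are illustrative.

#include "openvino/genai/continuous_batching_pipeline.hpp"

void stream_one_request(ov::genai::ContinuousBatchingPipeline& pipe) {
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    // Request id 0 is an arbitrary choice for this sketch.
    auto handle = pipe.add_request(0, "The Sun is yellow because", config);

    size_t produced = 0;
    while (pipe.has_non_finished_requests()) {
        pipe.step();
        if (handle->can_read()) {
            for (const auto& [out_id, output] : handle->read())
                produced += output.generated_ids.size();
        }
        if (produced > 64) {
            handle->stop();   // keep the turn in history; handle->cancel() would drop it instead
            break;
        }
    }
}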

src/cpp/include/openvino/genai/llm_pipeline.hpp (+4 -2)

@@ -18,8 +18,10 @@
 namespace ov {
 namespace genai {
 
-// Return flag corresponds whether generation should be stopped: false means continue generation, true means stop.
-using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
+// Return flag corresponds whether generation should be stopped. It could be:
+// an ov::genai::StreamingStatus flag: RUNNING means continue generation, STOP means stop generation, CANCEL means stop generation and remove the last prompt and answer from history
+// *DEPRECATED* a bool flag: false means continue generation, true means stop. Please use `ov::genai::StreamingStatus` instead.
+using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<StreamingStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
 using OptionalGenerationConfig = std::optional<GenerationConfig>;
 using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
 using StringInputs = std::variant<std::string, std::vector<std::string>>;
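To illustrate the widened variant (not part of the commit), both callback forms below should be accepted by generate(); the model directory is a placeholder.

#include "openvino/genai/llm_pipeline.hpp"
#include <functional>
#include <iostream>

int main() {
    ov::genai::LLMPipeline pipe("model_dir", "CPU");  // placeholder model directory

    // Deprecated bool form: false keeps generation running, true stops it.
    std::function<bool(std::string)> legacy_cb = [](std::string word) {
        std::cout << word << std::flush;
        return false;
    };

    // New form: return a StreamingStatus value (RUNNING, STOP or CANCEL).
    std::function<ov::genai::StreamingStatus(std::string)> status_cb = [](std::string word) {
        std::cout << word << std::flush;
        return ov::genai::StreamingStatus::RUNNING;
    };

    pipe.generate("The Sun is yellow because", ov::genai::streamer(legacy_cb), ov::genai::max_new_tokens(30));
    pipe.generate("The Sun is yellow because", ov::genai::streamer(status_cb), ov::genai::max_new_tokens(30));
}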

src/cpp/include/openvino/genai/streamer_base.hpp (+19 -2)

@@ -4,20 +4,37 @@
 #pragma once
 
 #include "openvino/genai/tokenizer.hpp"
+#include <variant>
 
 namespace ov {
 namespace genai {
 
+enum class StreamingStatus {
+    RUNNING = 0, // Continue running inference
+    STOP = 1,    // Stop generation, keep history as is, KV cache includes the last request and generated tokens
+    CANCEL = 2   // Stop generation, drop the last prompt and all generated tokens from history, KV cache includes the history except the last step
+};
+
 /**
  * @brief base class for streamers. In order to use inherit from from this class and implement put, and methods
  *
  * @param m_tokenizer tokenizer
  */
 class OPENVINO_GENAI_EXPORTS StreamerBase {
 public:
-    /// @brief put is called every time new token is decoded,
+    /// @brief put is called every time new token is decoded. Deprecated. Please use write() instead.
     /// @return bool flag to indicate whether generation should be stopped, if return true generation stops
-    virtual bool put(int64_t token) = 0;
+    OPENVINO_DEPRECATED("Please, use `write()` instead of `put()`. Support will be removed in 2026.0.0 release.")
+    virtual bool put(int64_t token) {
+        OPENVINO_THROW("This method is deprecated and will be removed in 2026.0.0 release. Please, override write() instead.");
+        return true;
+    };
+
+    /// @brief write is called every time new token is decoded
+    /// @return StreamingStatus flag to indicate whether generation should continue running, be stopped, or be cancelled
+    virtual StreamingStatus write(int64_t token) {
+        return put(token) ? StreamingStatus::STOP : StreamingStatus::RUNNING;
+    };
 
     /// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
     virtual void end() = 0;
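For orientation, a hedged sketch of a custom streamer built on the new interface; it is not part of this commit, and the whole-sequence re-decode plus the fixed token budget are simplifications chosen to keep the example short.

#include "openvino/genai/streamer_base.hpp"
#include <iostream>
#include <vector>

// Stops generation after a fixed number of tokens; the budget is an illustrative choice.
class BudgetStreamer : public ov::genai::StreamerBase {
public:
    BudgetStreamer(ov::genai::Tokenizer tokenizer, size_t budget)
        : m_tokenizer(std::move(tokenizer)), m_budget(budget) {}

    ov::genai::StreamingStatus write(int64_t token) override {
        m_tokens.push_back(token);
        // Naive: re-decode the whole sequence on every step (fine for a short demo).
        std::cout << "\r" << m_tokenizer.decode(m_tokens) << std::flush;
        return m_tokens.size() < m_budget ? ov::genai::StreamingStatus::RUNNING
                                          : ov::genai::StreamingStatus::STOP;
    }

    void end() override { std::cout << std::endl; }

private:
    ov::genai::Tokenizer m_tokenizer;
    std::vector<int64_t> m_tokens;
    size_t m_budget;
};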

src/cpp/include/openvino/genai/whisper_pipeline.hpp (+12 -3)

@@ -25,14 +25,23 @@ using RawSpeechInput = std::vector<float>;
 */
 class OPENVINO_GENAI_EXPORTS ChunkStreamerBase : public StreamerBase {
 public:
-    /// @brief put is called every time new token chunk is generated,
+    /// @brief put_chunk is called every time new token chunk is generated,
     /// @return bool flag to indicate whether generation should be stopped, if return true generation stops
-    virtual bool put_chunk(std::vector<int64_t> tokens) = 0;
+    virtual bool put_chunk(std::vector<int64_t> tokens) {
+        OPENVINO_THROW("This method is deprecated and will be removed in 2026.0.0 release. Please, override write_chunk() instead.");
+        return true;
+    }
+
+    /// @brief write_chunk is called every time new token chunk is generated
+    /// @return StreamingStatus flag to indicate whether generation should be stopped
+    virtual StreamingStatus write_chunk(std::vector<int64_t> tokens) {
+        return put_chunk(tokens) ? StreamingStatus::STOP : StreamingStatus::RUNNING;
+    }
 };
 
 // Return flag corresponds whether generation should be stopped: false means continue generation, true means stop.
 using ChunkStreamerVariant =
-    std::variant<std::function<bool(std::string)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;
+    std::variant<std::function<bool(std::string)>, std::function<StreamingStatus(std::string)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;
 
 struct OPENVINO_GENAI_EXPORTS WhisperRawPerfMetrics {
     /** @brief Duration for each features extraction call */
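A sketch (not from the commit) of passing the new status-returning callback to the Whisper pipeline; the model directory and the silent dummy audio are placeholders for real input.

#include "openvino/genai/whisper_pipeline.hpp"
#include <iostream>

int main() {
    ov::genai::WhisperPipeline pipe("whisper_model_dir", "CPU");  // placeholder model directory

    // Placeholder input: one second of silence at 16 kHz instead of a real recording.
    ov::genai::RawSpeechInput speech(16000, 0.0f);

    auto config = pipe.get_generation_config();
    config.max_new_tokens = 100;

    // The new callback form returns StreamingStatus instead of bool.
    auto streamer = [](std::string chunk) {
        std::cout << chunk << std::flush;
        return ov::genai::StreamingStatus::RUNNING;
    };

    pipe.generate(speech, config, streamer);
}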

src/cpp/src/continuous_batching_adapter.hpp (+2 -2)

@@ -90,7 +90,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
     std::vector<std::string> plain_replies;
     std::vector<float> plain_scores;
     for (GenerationResult& res : generated) {
-        OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
+        OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP || res.m_status == GenerationStatus::CANCEL, "Got unfinished GenerationStatus");
         std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies));
         std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
     }

@@ -182,7 +182,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
     std::vector<std::vector<int64_t>> plain_tokens;
     std::vector<float> plain_scores;
     for (EncodedGenerationResult& res : generated) {
-        OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
+        OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP || res.m_status == GenerationStatus::CANCEL, "Got unfinished GenerationStatus");
         std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_tokens));
         std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
     }

src/cpp/src/continuous_batching_impl.cpp (+7 -15)

@@ -432,17 +432,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
     }
     set_adapters(sampling_params[0].adapters);
 
-    const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
-        [](std::monostate) -> std::shared_ptr<StreamerBase> {
-            return nullptr;
-        },
-        [](const std::shared_ptr<StreamerBase>& streamer) {
-            return streamer;
-        },
-        [this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
-            return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
-        }
-    }, streamer);
+    const std::shared_ptr<StreamerBase>& streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer);
 
     OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 &&
         (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()),

@@ -479,8 +469,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
             OPENVINO_ASSERT(generation_outputs.size() <= 1);
             if (!generation_outputs.empty()) {
                 for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) {
-                    if (streamer_ptr->put(generated_token_id)) {
-                        generation->drop();
+                    auto streaming_status = streamer_ptr->write(generated_token_id);
+                    if (streaming_status != ov::genai::StreamingStatus::RUNNING) {
+                        streaming_status == ov::genai::StreamingStatus::CANCEL ? generation->cancel() : generation->stop();
                         break;
                     }
                 }

@@ -540,6 +531,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
         result.m_request_id = request_id;
         result.m_generation_ids.resize(num_outputs);
         result.m_scores.resize(num_outputs);
+        result.m_status = request->get_generation_stream()->get_status();
 
         for (size_t i = 0; i < num_outputs; ++i) {
             const auto & sequence = sequences[i];

@@ -574,7 +566,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_reque
     std::vector<SequenceGroup::Ptr>::iterator requests_iterator = m_requests.begin();
     while (requests_iterator != m_requests.end()) {
         const auto& request = *requests_iterator;
-        if (request->has_finished() || request->handle_dropped()) {
+        if (request->has_finished() || request->handle_stopped() || request->handle_cancelled()) {
             for (const auto& sequence: request->get_sequences()) {
                 if (m_scheduler->has_block_table(sequence->get_id())) {
                     m_scheduler->free_sequence(sequence->get_id());

@@ -592,7 +584,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_notify_requests_droppe
     // Notify the last time by pushing empty output
     // This causes read() to unblock by adding anything to the queue
     for (SequenceGroup::Ptr& request : m_requests) {
-        if (request->handle_dropped())
+        if (request->handle_stopped() || request->handle_cancelled())
             request->push_empty_outputs();
     }
 }
