
Commit f997c42

Committed Sep 25, 2024
Move methods to sampler.cpp & move private fields from interface to impl
1 parent e63ccda commit f997c42
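
The change follows the standard header/implementation split: helpers that were defined inline in sampler.hpp keep only a declaration there and get their definitions in sampler.cpp, while pipeline-specific members move from the abstract ImplInterface into the concrete ContinuousBatchingImpl. A minimal sketch of that pattern, with hypothetical names (widget.hpp/widget.cpp and clamp_index are illustrative, not files or functions from this repository):

// widget.hpp -- before: the body lives in the header
// inline int clamp_index(int i, int size) { return i < 0 ? 0 : (i >= size ? size - 1 : i); }

// widget.hpp -- after: declaration only
int clamp_index(int i, int size);

// widget.cpp -- after: the definition moves here
int clamp_index(int i, int size) {
    return i < 0 ? 0 : (i >= size ? size - 1 : i);
}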

4 files changed: +363 -346 lines

 

src/cpp/src/continuous_batching_impl.hpp (+12 lines)
@@ -9,6 +9,18 @@
 namespace ov::genai {
 class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::ImplInterface {
 protected:
+    std::shared_ptr<Scheduler> m_scheduler;
+    std::shared_ptr<CacheManager> m_cache_manager;
+    std::shared_ptr<ModelRunner> m_model_runner;
+    std::shared_ptr<Sampler> m_sampler;
+
+    // current requests to process
+    std::vector<SequenceGroup::Ptr> m_requests;
+    // requests added to the pipeline that will be added to m_requests in the next iteration
+    std::vector<SequenceGroup::Ptr> m_awaiting_requests;
+    // Mutex protecting access to m_awaiting_requests, so add_request and step methods can be called from different threads
+    std::mutex m_awaiting_requests_mutex;
+
     void _free_non_running_requests();
     void _notify_requests_dropped_by_handle();
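
The comments on the new fields describe a two-stage hand-off: add_request only appends to m_awaiting_requests under the mutex, and the next step() drains that buffer into m_requests, which is what lets the two methods run on different threads. A minimal sketch of that pattern under those assumptions (Request and RequestQueue are illustrative stand-ins, not the pipeline's real classes):

#include <mutex>
#include <vector>

struct Request { int id; };

class RequestQueue {
    std::vector<Request> m_requests;            // current requests to process (worker thread only)
    std::vector<Request> m_awaiting_requests;   // filled by add_request, drained by step
    std::mutex m_awaiting_requests_mutex;

public:
    void add_request(Request r) {               // may be called from any thread
        std::lock_guard<std::mutex> lock(m_awaiting_requests_mutex);
        m_awaiting_requests.push_back(r);
    }

    void step() {                               // called from the worker thread
        {
            std::lock_guard<std::mutex> lock(m_awaiting_requests_mutex);
            m_requests.insert(m_requests.end(),
                              m_awaiting_requests.begin(), m_awaiting_requests.end());
            m_awaiting_requests.clear();
        }
        // ... schedule, run the model and sample for m_requests ...
    }
};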

src/cpp/src/continuous_batching_impl_interface.hpp (-11 lines)
@@ -15,10 +15,6 @@ namespace ov::genai {
1515
class ContinuousBatchingPipeline::ImplInterface {
1616
protected:
1717
Tokenizer m_tokenizer;
18-
std::shared_ptr<Scheduler> m_scheduler;
19-
std::shared_ptr<CacheManager> m_cache_manager;
20-
std::shared_ptr<ModelRunner> m_model_runner;
21-
std::shared_ptr<Sampler> m_sampler;
2218

2319
// TODO (mzegla): GenerationConfig is request specific object
2420
// and pipeline only uses default rng_seed.
@@ -39,13 +35,6 @@ class ContinuousBatchingPipeline::ImplInterface {
             std::cout << std::endl;
         }
     } m_perf;
-
-    // current requests to process
-    std::vector<SequenceGroup::Ptr> m_requests;
-    // requests added to the pipeline that will be added to m_requests in the next iteration
-    std::vector<SequenceGroup::Ptr> m_awaiting_requests;
-    // Mutex protecting access to m_awaiting_requests, so add_request and step methods can be called from different threads
-    std::mutex m_awaiting_requests_mutex;
     bool m_is_chat_conversation = false;
     ChatHistory m_history;

src/cpp/src/sampler.cpp (+338 lines)
@@ -5,6 +5,206 @@
 #include "sampler.hpp"

 namespace ov::genai {
+std::vector<int64_t> kmp_search(const std::vector<int64_t>& haystack, const std::vector<int64_t>& needle) {
+    if (needle.empty()) { // no_repeat_ngram_size == 1, ban every token
+        return {haystack.begin(), haystack.end()};
+    }
+    std::vector<int> partial_match_table(needle.size() + 1, -1);
+    int cnd = 0;
+    for (size_t pos = 1; pos < needle.size(); ++pos) {
+        if (needle.at(pos) == needle.at(size_t(cnd))) {
+            partial_match_table.at(pos) = partial_match_table.at(size_t(cnd));
+        } else {
+            partial_match_table.at(pos) = cnd;
+            while (cnd >= 0 && needle.at(pos) != needle.at(size_t(cnd))) {
+                cnd = partial_match_table.at(size_t(cnd));
+            }
+        }
+        ++cnd;
+    }
+    partial_match_table.back() = cnd;
+    std::vector<int64_t> res;
+    size_t haystack_id = 0;
+    int needle_id = 0;
+    while (haystack_id < haystack.size() - 1) {
+        if (needle.at(size_t(needle_id)) == haystack.at(haystack_id)) {
+            ++haystack_id;
+            ++needle_id;
+            if (needle_id == int(needle.size())) {
+                res.push_back(haystack.at(haystack_id));
+                needle_id = partial_match_table.at(size_t(needle_id));
+            }
+        } else {
+            needle_id = partial_match_table.at(size_t(needle_id));
+            if (needle_id < 0) {
+                ++haystack_id;
+                ++needle_id;
+            }
+        }
+    }
+    return res;
+}
+
+std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx) {
+    ov::Shape shape = logits.get_shape();
+    OPENVINO_ASSERT(shape.size() == 3);
+    size_t batch = shape[0], seq_len = shape[1], vocab_size = shape[2];
+    OPENVINO_ASSERT(batch_idx < batch, "Logits batch size doesn't match the number of beams");
+
+    size_t batch_offset = batch_idx * seq_len * vocab_size, sequence_offset = (seq_len - 1) * vocab_size;
+    const float* beam_logits = logits.data<const float>() + batch_offset + sequence_offset;
+    float max_logit = *std::max_element(beam_logits, beam_logits + vocab_size);
+    float log_sum = std::log(std::accumulate(
+        beam_logits, beam_logits + vocab_size, 0.0f, [max_logit](float accumulated, float to_add) {
+            return accumulated + std::exp(to_add - max_logit);
+        }));
+
+    std::vector<Token> tokens;
+    tokens.reserve(vocab_size);
+    for (size_t idx = 0; idx < vocab_size; ++idx)
+        tokens.push_back({beam_logits[idx] - max_logit - log_sum, int64_t(idx)});
+
+    return tokens;
+}
+
+std::vector<int64_t> wrap_tokens(const std::vector<int64_t>& tokens, const std::vector<int64_t>& prefix_tokens, const std::vector<int64_t>& suffix_tokens) {
+    std::vector<int64_t> all_tokens = prefix_tokens;
+    all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());
+    all_tokens.insert(all_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+    return all_tokens;
+}
+
+std::string clean_wrapped_text(const std::string& wrapped_text, const std::string& prefix, const std::string& suffix) {
+    auto prefix_pos = wrapped_text.find(prefix);
+    OPENVINO_ASSERT(prefix_pos != std::string::npos);
+    auto suffix_pos = wrapped_text.rfind(suffix);
+    OPENVINO_ASSERT(suffix_pos != std::string::npos);
+    auto clean_text_start = prefix_pos + prefix.size();
+    auto clean_text_length = suffix_pos - clean_text_start;
+    std::string clean_text = wrapped_text.substr(clean_text_start, clean_text_length);
+    return clean_text;
+}
+
+int match_stop_string(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings) {
+    /*
+    For catching stop_string hit we run comparisons character-wise to catch cases where stop string
+    overlaps with part of another token on both sides or is just a part of a single token.
+    For every stop_string we iterate over generated tokens starting from the last one and going backwards.
+    Every token is wrapped with prefix tokens to ensure tokenizer doesn't remove prefix whitespace of the actual token.
+    After that all tokens are decoded and prefix is removed from the decoded text, so we end up with decoded token.
+    Its characters are compared to the stop_string character at a current_position
+    (position of a character in the stop_string counting from the last one) - at the beginning position is 0.
+    When characters match we increase current_position and check if we have a full match already, if not we continue.
+    If we have already matched some characters (current_position > 0) and next character is not matching
+    before we reach the full match, then we reset current_position to 0.
+    */
+    std::string prefix = "a";
+    auto prefix_ov = tokenizer.encode(prefix).input_ids;
+    std::vector<int64_t> prefix_tokens(prefix_ov.data<int64_t>(), prefix_ov.data<int64_t>() + prefix_ov.get_size());
+    std::string suffix = "b";
+    auto suffix_ov = tokenizer.encode(suffix).input_ids;
+    std::vector<int64_t> suffix_tokens(suffix_ov.data<int64_t>(), suffix_ov.data<int64_t>() + suffix_ov.get_size());
+
+    // Since whitespace can be added at the beginning of the suffix we also try to capture that behavior here
+    // and get suffix string that will actually be part of the decoded string so we can remove it correctly
+    auto wrapped_suffix_tokens = suffix_tokens;
+    wrapped_suffix_tokens.insert(wrapped_suffix_tokens.begin(), prefix_tokens.begin(), prefix_tokens.end());
+    std::string wrapped_suffix = tokenizer.decode(wrapped_suffix_tokens);
+    auto wrapper_pos = wrapped_suffix.find(prefix);
+    suffix = wrapped_suffix.substr(wrapper_pos + prefix.size());
+
+    for (auto stop_string: stop_strings) {
+        int current_position = 0;
+        int num_matched_tokens = 0;
+        // Getting reverse iterator to check tokens starting from the last one generated and going backwards
+        auto generated_tokens_rit = generated_tokens.rbegin();
+        std::vector<int64_t> tokens_buffer;
+        while (generated_tokens_rit != generated_tokens.rend()) {
+            num_matched_tokens++;
+            tokens_buffer.insert(tokens_buffer.begin(), *generated_tokens_rit);
+
+            std::vector<int64_t> wrapped_tokens = wrap_tokens(tokens_buffer, prefix_tokens, suffix_tokens);
+            std::string wrapped_text = tokenizer.decode(wrapped_tokens);
+            std::string clean_text = clean_wrapped_text(wrapped_text, prefix, suffix);
+
+            if (clean_text == "" || (clean_text.size() >= 3 && (clean_text.compare(clean_text.size() - 3, 3, "�") == 0))) {
+                generated_tokens_rit++;
+                continue;
+            } else {
+                tokens_buffer.clear();
+            }
+            // Checking clean_text characters starting from the last one
+            for (auto clean_text_rit = clean_text.rbegin(); clean_text_rit != clean_text.rend(); clean_text_rit++) {
+                // On character match increment current_position for the next comparisons
+                if (*clean_text_rit == *(stop_string.rbegin() + current_position)) {
+                    current_position++;
+                    // If this is the last character from the stop_string we have a match
+                    if ((stop_string.rbegin() + current_position) == stop_string.rend()) {
+                        return num_matched_tokens;
+                    }
+                } else if (current_position) {
+                    // Already found matching characters, but the last one didn't match, so we reset current_position
+                    current_position = 0;
+                    // Looking for the match will start over from this character so we decrement iterator
+                    clean_text_rit--;
+                }
+            }
+            generated_tokens_rit++;
+        }
+    }
+    return 0;
+}
+
+int match_stop_string2(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings) {
+    for (auto stop_string: stop_strings) {
+        auto stop_tokens_ov = tokenizer.encode(stop_string).input_ids;
+        size_t num_tokens = stop_tokens_ov.get_size();
+        if(num_tokens > generated_tokens.size())
+            continue;
+
+        // Check direct token match
+        std::vector<int64_t> stop_tokens(stop_tokens_ov.data<int64_t>(), stop_tokens_ov.data<int64_t>() + num_tokens);
+        std::vector<int64_t> last_generated_tokens(generated_tokens.end()-num_tokens, generated_tokens.end());
+        if (stop_tokens == last_generated_tokens)
+            return num_tokens;
+
+        // Continue checking chunks of 4 tokens
+        num_tokens += 4;
+        while (num_tokens <= generated_tokens.size()) {
+            std::vector<int64_t> last_generated_tokens(generated_tokens.end()-num_tokens, generated_tokens.end());
+            std::string decoded_last_tokens = tokenizer.decode(last_generated_tokens);
+            if (decoded_last_tokens.find(stop_string) != std::string::npos) {
+                return num_tokens;
+            }
+            num_tokens += 4;
+        }
+    }
+    return 0;
+}
+
+void GroupBeamSearcher::finalize(SamplerOutput& sampler_output) {
+    for (Group& group : m_groups) {
+        if (!group.done) {
+            for (Beam& beam : group.ongoing) {
+                uint64_t sequence_id = beam.m_sequence->get_id();
+
+                int64_t preempted_id = group.finish(beam, m_parameters);
+                if (preempted_id >= 0) {
+                    // remove preempted one
+                    m_sequence_group->remove_sequence(preempted_id);
+                }
+
+                // mark current sequence as finished
+                beam.m_sequence->set_status(SequenceStatus::FINISHED);
+                // Setting length since this function is used when sequence generated tokens number reaches max_new_tokens
+                beam.m_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
+                // we also need to drop ongoing / forked sequences from the scheduler
+                sampler_output.m_dropped_sequences.push_back(sequence_id);
+            }
+        }
+    }
+}
+
 GroupBeamSearcher::GroupBeamSearcher(SequenceGroup::Ptr sequence_group, Tokenizer tokenizer)
     : m_sequence_group(sequence_group),
       m_parameters{m_sequence_group->get_sampling_parameters()},
@@ -261,6 +461,96 @@ void GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutp
     }
 }

+Logits Sampler::_get_logit_vector(ov::Tensor logits, size_t batch_idx) {
+    ov::Shape logits_shape = logits.get_shape();
+    size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
+    OPENVINO_ASSERT(batch_idx <= batch_size);
+    size_t batch_offset = batch_idx * seq_len * vocab_size;
+    size_t sequence_offset = (seq_len - 1) * vocab_size;
+    float* logits_data = logits.data<float>() + batch_offset + sequence_offset;
+
+    return Logits{logits_data, vocab_size};
+}
+
+Token Sampler::_greedy_sample(const Logits& logits) const {
+    // For greedy sampling we do not expect sorting or shrinking considered tokens
+    // so we can operate directly on the data buffer
+    float max_value = -std::numeric_limits<float>::infinity();
+    size_t max_index = 0;
+    for (size_t i = 0; i < logits.m_size; ++i) {
+        if (logits.m_data[i] > max_value) {
+            max_value = logits.m_data[i];
+            max_index = i;
+        }
+    }
+
+    // apply log softmax to max value
+    float log_sum = std::log(std::accumulate(
+        logits.m_data, logits.m_data + logits.m_size, 0.0f, [max_value](float accumulated, float to_add) {
+            return accumulated + std::exp(to_add - max_value);
+        }));
+    max_value = -log_sum;
+
+    return Token(max_value, max_index);
+}
+
+std::vector<Token> Sampler::_multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence) {
+    // If top_p or top_k was applied we use sorted vector, if not we go with original buffer.
+    std::vector<float> multinomial_weights;
+    multinomial_weights.reserve(logits.m_size);
+    if (logits.is_vector_initialized())
+        for (auto& logit: logits.m_vector) multinomial_weights.emplace_back(logit.m_log_prob);
+    else
+        multinomial_weights.assign(logits.m_data, logits.m_data + logits.m_size);
+
+    auto dist = std::discrete_distribution<size_t>(multinomial_weights.begin(), multinomial_weights.end()); // equivalent to multinomial with number of trials == 1
+
+    std::vector<Token> out_tokens;
+    for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) {
+        size_t element_to_pick = dist(rng_engine);
+        if (logits.is_vector_initialized())
+            out_tokens.push_back(logits.m_vector[element_to_pick]);
+        else
+            out_tokens.emplace_back(logits.m_data[element_to_pick], element_to_pick);
+    }
+    return out_tokens;
+}
+
+std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequence_group) {
+    auto sampling_params = sequence_group->get_sampling_parameters();
+    std::vector<int64_t> dropped_seq_ids;
+    for (auto& running_sequence : sequence_group->get_running_sequences()) {
+        const auto generated_len = running_sequence->get_generated_len();
+        if (sampling_params.max_new_tokens == generated_len ||
+            is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
+            // stop sequence by max_new_tokens or stop token (eos included)
+            running_sequence->set_status(SequenceStatus::FINISHED);
+
+            if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
+                running_sequence->set_finish_reason(GenerationFinishReason::STOP);
+            } else if (sampling_params.max_new_tokens == generated_len) {
+                running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
+            }
+
+            dropped_seq_ids.push_back(running_sequence->get_id());
+            continue;
+        }
+
+        if (!sampling_params.stop_strings.empty()) {
+            int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings);
+            if (num_matched_last_tokens) {
+                if (!sampling_params.include_stop_str_in_output)
+                    running_sequence->remove_last_tokens(num_matched_last_tokens);
+                running_sequence->set_status(SequenceStatus::FINISHED);
+                running_sequence->set_finish_reason(GenerationFinishReason::STOP);
+                dropped_seq_ids.push_back(running_sequence->get_id());
+            }
+        }
+    }
+    return dropped_seq_ids;
+}
+
+
 SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits) {
     const float * logits_data = logits.data<float>();
     ov::Shape logits_shape = logits.get_shape();
@@ -370,5 +660,53 @@ void Sampler::clear_beam_search_info(uint64_t request_id) {
     m_beam_search_info.erase(request_id);
 }

+int64_t Group::finish(Beam beam, const ov::genai::GenerationConfig& sampling_params) {
+    int64_t preeempted_sequence_id = -1;
+    float generated_len = beam.get_generated_len() + (is_stop_token_id_hit(beam.m_token_id, sampling_params.stop_token_ids) ? 1 : 0); // HF counts EOS token in generation length
+    beam.m_score /= std::pow(generated_len, sampling_params.length_penalty);

+    min_heap.push_back(beam);
+    std::push_heap(min_heap.begin(), min_heap.end(), greater);
+    assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
+        "number of beams should be divisible by number of groups");
+    size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
+    if (min_heap.size() > group_size) {
+        std::pop_heap(min_heap.begin(), min_heap.end(), greater);
+        preeempted_sequence_id = min_heap.back().m_sequence->get_id();
+        min_heap.pop_back();
+    }
+
+    return preeempted_sequence_id;
+}
+
+void Group::is_done(const ov::genai::GenerationConfig& sampling_params) {
+    assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
+        "number of beams should be divisible by number of groups");
+    size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
+    if (min_heap.size() < group_size)
+        return;
+
+    const Beam& best_running_sequence = ongoing.front(), & worst_finished_sequence = min_heap.front();
+    size_t cur_len = best_running_sequence.m_sequence->get_generated_len();
+    float best_sum_logprobs = best_running_sequence.m_score;
+    float worst_score = worst_finished_sequence.m_score;
+    switch (sampling_params.stop_criteria) {
+    case ov::genai::StopCriteria::EARLY:
+        done = true;
+        return;
+    case ov::genai::StopCriteria::HEURISTIC: {
+        float highest_attainable_score = best_sum_logprobs / std::pow(float(cur_len), sampling_params.length_penalty);
+        done = worst_score >= highest_attainable_score;
+        return;
+    }
+    case ov::genai::StopCriteria::NEVER: {
+        size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.max_new_tokens : cur_len;
+        float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
+        done = worst_score >= highest_attainable_score;
+        return;
+    }
+    default:
+        OPENVINO_THROW("Beam search internal error: unknown mode");
+    }
+}
 }
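
To make the contract of the kmp_search helper above concrete: given the tokens generated so far (the haystack) and an n-gram prefix (the needle), it returns every token that followed an occurrence of the needle, which is the set a no_repeat_ngram ban would then exclude. The sketch below is a naive reference with the same contract, not the library's implementation; main and the sample values are illustrative only.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Naive reference: return the token that follows every occurrence of `needle` in `haystack`.
std::vector<int64_t> tokens_after_needle(const std::vector<int64_t>& haystack,
                                         const std::vector<int64_t>& needle) {
    if (needle.empty())          // no_repeat_ngram_size == 1: ban every generated token
        return haystack;
    std::vector<int64_t> res;
    for (size_t i = 0; i + needle.size() < haystack.size(); ++i) {
        if (std::equal(needle.begin(), needle.end(), haystack.begin() + i))
            res.push_back(haystack[i + needle.size()]);
    }
    return res;
}

int main() {
    // With no_repeat_ngram_size == 3 the needle would be the last two generated tokens, here {7, 8}.
    std::vector<int64_t> haystack{1, 7, 8, 9, 2, 7, 8, 4, 7, 8};
    std::vector<int64_t> needle{7, 8};
    for (int64_t banned : tokens_after_needle(haystack, needle))
        std::cout << banned << ' ';   // prints: 9 4
    std::cout << '\n';
}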

src/cpp/src/sampler.hpp (+13 -335 lines)
@@ -22,189 +22,21 @@

 namespace ov::genai {
 // Modified Knuth–Morris–Pratt algorithm which returns tokens following after every needle occurrence in haystack
-inline std::vector<int64_t> kmp_search(const std::vector<int64_t>& haystack, const std::vector<int64_t>& needle) {
-    if (needle.empty()) { // no_repeat_ngram_size == 1, ban every token
-        return {haystack.begin(), haystack.end()};
-    }
-    std::vector<int> partial_match_table(needle.size() + 1, -1);
-    int cnd = 0;
-    for (size_t pos = 1; pos < needle.size(); ++pos) {
-        if (needle.at(pos) == needle.at(size_t(cnd))) {
-            partial_match_table.at(pos) = partial_match_table.at(size_t(cnd));
-        } else {
-            partial_match_table.at(pos) = cnd;
-            while (cnd >= 0 && needle.at(pos) != needle.at(size_t(cnd))) {
-                cnd = partial_match_table.at(size_t(cnd));
-            }
-        }
-        ++cnd;
-    }
-    partial_match_table.back() = cnd;
-    std::vector<int64_t> res;
-    size_t haystack_id = 0;
-    int needle_id = 0;
-    while (haystack_id < haystack.size() - 1) {
-        if (needle.at(size_t(needle_id)) == haystack.at(haystack_id)) {
-            ++haystack_id;
-            ++needle_id;
-            if (needle_id == int(needle.size())) {
-                res.push_back(haystack.at(haystack_id));
-                needle_id = partial_match_table.at(size_t(needle_id));
-            }
-        } else {
-            needle_id = partial_match_table.at(size_t(needle_id));
-            if (needle_id < 0) {
-                ++haystack_id;
-                ++needle_id;
-            }
-        }
-    }
-    return res;
-}
+std::vector<int64_t> kmp_search(const std::vector<int64_t>& haystack, const std::vector<int64_t>& needle);

-inline std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx) {
-    ov::Shape shape = logits.get_shape();
-    OPENVINO_ASSERT(shape.size() == 3);
-    size_t batch = shape[0], seq_len = shape[1], vocab_size = shape[2];
-    OPENVINO_ASSERT(batch_idx < batch, "Logits batch size doesn't match the number of beams");
-
-    size_t batch_offset = batch_idx * seq_len * vocab_size, sequence_offset = (seq_len - 1) * vocab_size;
-    const float* beam_logits = logits.data<const float>() + batch_offset + sequence_offset;
-    float max_logit = *std::max_element(beam_logits, beam_logits + vocab_size);
-    float log_sum = std::log(std::accumulate(
-        beam_logits, beam_logits + vocab_size, 0.0f, [max_logit](float accumulated, float to_add) {
-            return accumulated + std::exp(to_add - max_logit);
-        }));
-
-    std::vector<Token> tokens;
-    tokens.reserve(vocab_size);
-    for (size_t idx = 0; idx < vocab_size; ++idx)
-        tokens.push_back({beam_logits[idx] - max_logit - log_sum, int64_t(idx)});
-
-    return tokens;
-}
+std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx);

-inline std::vector<int64_t>
-wrap_tokens(const std::vector<int64_t>& tokens, const std::vector<int64_t>& prefix_tokens, const std::vector<int64_t>& suffix_tokens) {
-    std::vector<int64_t> all_tokens = prefix_tokens;
-    all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());
-    all_tokens.insert(all_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-    return all_tokens;
-}
+std::vector<int64_t> wrap_tokens(const std::vector<int64_t>& tokens, const std::vector<int64_t>& prefix_tokens, const std::vector<int64_t>& suffix_tokens);

-inline std::string clean_wrapped_text(const std::string& wrapped_text, const std::string& prefix, const std::string& suffix) {
-    auto prefix_pos = wrapped_text.find(prefix);
-    OPENVINO_ASSERT(prefix_pos != std::string::npos);
-    auto suffix_pos = wrapped_text.rfind(suffix);
-    OPENVINO_ASSERT(suffix_pos != std::string::npos);
-    auto clean_text_start = prefix_pos + prefix.size();
-    auto clean_text_length = suffix_pos - clean_text_start;
-    std::string clean_text = wrapped_text.substr(clean_text_start, clean_text_length);
-    return clean_text;
-}
+std::string clean_wrapped_text(const std::string& wrapped_text, const std::string& prefix, const std::string& suffix);

 // Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
-inline int
-match_stop_string(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings) {
-    /*
-    For catching stop_string hit we run comparisons character-wise to catch cases where stop string
-    overlaps with part of another token on both sides or is just a part of a single token.
-    For every stop_string we iterate over generated tokens starting from the last one and going backwards.
-    Every token is wrapped with prefix tokens to ensure tokenizer doesn't remove prefix whitespace of the actual token.
-    After that all tokens are decoded and prefix is removed from the decoded text, so we end up with decoded token.
-    Its characters are compared to the stop_string character at a current_position
-    (position of a character in the stop_string counting from the last one) - at the begining position is 0.
-    When characters match we increase current_position and check if we have a full match already, if not we continue.
-    If we have already matched some characters (current_position > 0) and next character is not matching
-    before we reach the full match, then we reset current_position to 0.
-    */
-    std::string prefix = "a";
-    auto prefix_ov = tokenizer.encode(prefix).input_ids;
-    std::vector<int64_t> prefix_tokens(prefix_ov.data<int64_t>(), prefix_ov.data<int64_t>() + prefix_ov.get_size());
-    std::string suffix = "b";
-    auto suffix_ov = tokenizer.encode(suffix).input_ids;
-    std::vector<int64_t> suffix_tokens(suffix_ov.data<int64_t>(), suffix_ov.data<int64_t>() + suffix_ov.get_size());
-
-    // Since whitespace can be added at the beginning of the suffix we also try to capture that behavior here
-    // and get suffix string that will actually be part of the decoded string so we can remove it correctly
-    auto wrapped_suffix_tokens = suffix_tokens;
-    wrapped_suffix_tokens.insert(wrapped_suffix_tokens.begin(), prefix_tokens.begin(), prefix_tokens.end());
-    std::string wrapped_suffix = tokenizer.decode(wrapped_suffix_tokens);
-    auto wrapper_pos = wrapped_suffix.find(prefix);
-    suffix = wrapped_suffix.substr(wrapper_pos + prefix.size());
-
-    for (auto stop_string: stop_strings) {
-        int current_position = 0;
-        int num_matched_tokens = 0;
-        // Getting reverse iterator to check tokens starting from the last one generated and going backwards
-        auto generated_tokens_rit = generated_tokens.rbegin();
-        std::vector<int64_t> tokens_buffer;
-        while (generated_tokens_rit != generated_tokens.rend()) {
-            num_matched_tokens++;
-            tokens_buffer.insert(tokens_buffer.begin(), *generated_tokens_rit);
-
-            std::vector<int64_t> wrapped_tokens = wrap_tokens(tokens_buffer, prefix_tokens, suffix_tokens);
-            std::string wrapped_text = tokenizer.decode(wrapped_tokens);
-            std::string clean_text = clean_wrapped_text(wrapped_text, prefix, suffix);
-
-            if (clean_text == "" || (clean_text.size() >= 3 && (clean_text.compare(clean_text.size() - 3, 3, "�") == 0))) {
-                generated_tokens_rit++;
-                continue;
-            } else {
-                tokens_buffer.clear();
-            }
-            // Checking clean_text characters starting from the last one
-            for (auto clean_text_rit = clean_text.rbegin(); clean_text_rit != clean_text.rend(); clean_text_rit++) {
-                // On character match increment current_position for the next comparisons
-                if (*clean_text_rit == *(stop_string.rbegin() + current_position)) {
-                    current_position++;
-                    // If this is the last character from the stop_string we have a match
-                    if ((stop_string.rbegin() + current_position) == stop_string.rend()) {
-                        return num_matched_tokens;
-                    }
-                } else if (current_position) {
-                    // Already found matching characters, but the last one didn't match, so we reset current_position
-                    current_position = 0;
-                    // Looking for the match will start over from this character so we decrement iterator
-                    clean_text_rit--;
-                }
-            }
-            generated_tokens_rit++;
-        }
-    }
-    return 0;
-}
+int match_stop_string(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings);

 // Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
 // Number of tokens might not be exact as if there's no direct token match, we decode generated tokens incrementally expanding decoding scope
 // with 4 next tokens with each iteration until we check all tokens.
-inline int match_stop_string2(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings) {
-    for (auto stop_string: stop_strings) {
-        auto stop_tokens_ov = tokenizer.encode(stop_string).input_ids;
-        size_t num_tokens = stop_tokens_ov.get_size();
-        if(num_tokens > generated_tokens.size())
-            continue;
-
-        // Check direct token match
-        std::vector<int64_t> stop_tokens(stop_tokens_ov.data<int64_t>(), stop_tokens_ov.data<int64_t>() + num_tokens);
-        std::vector<int64_t> last_generated_tokens(generated_tokens.end()-num_tokens, generated_tokens.end());
-        if (stop_tokens == last_generated_tokens)
-            return num_tokens;
-
-        // Continue checking chunks of 4 tokens
-        num_tokens += 4;
-        while (num_tokens <= generated_tokens.size()) {
-            std::vector<int64_t> last_generated_tokens(generated_tokens.end()-num_tokens, generated_tokens.end());
-            std::string decoded_last_tokens = tokenizer.decode(last_generated_tokens);
-            if (decoded_last_tokens.find(stop_string) != std::string::npos) {
-                return num_tokens;
-            }
-            num_tokens += 4;
-        }
-    }
-
-    return 0;
-}
+int match_stop_string2(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings);

 // Handle stop_token_ids
 inline bool is_stop_token_id_hit(int64_t generated_token, const std::set<int64_t> & stop_token_ids) {
@@ -243,55 +75,8 @@ struct Group {
     std::vector<Beam> min_heap; // The worst of the best completed beams is the first
     bool done = false;

-    int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params) {
-        int64_t preeempted_sequence_id = -1;
-        float generated_len = beam.get_generated_len() + (is_stop_token_id_hit(beam.m_token_id, sampling_params.stop_token_ids) ? 1 : 0); // HF counts EOS token in generation length
-        beam.m_score /= std::pow(generated_len, sampling_params.length_penalty);
-
-        min_heap.push_back(beam);
-        std::push_heap(min_heap.begin(), min_heap.end(), greater);
-        assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
-            "number of beams should be divisible by number of groups");
-        size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
-        if (min_heap.size() > group_size) {
-            std::pop_heap(min_heap.begin(), min_heap.end(), greater);
-            preeempted_sequence_id = min_heap.back().m_sequence->get_id();
-            min_heap.pop_back();
-        }
-
-        return preeempted_sequence_id;
-    }
-
-    void is_done(const ov::genai::GenerationConfig& sampling_params) {
-        assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
-            "number of beams should be divisible by number of groups");
-        size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
-        if (min_heap.size() < group_size)
-            return;
-
-        const Beam& best_running_sequence = ongoing.front(), & worst_finished_sequence = min_heap.front();
-        size_t cur_len = best_running_sequence.m_sequence->get_generated_len();
-        float best_sum_logprobs = best_running_sequence.m_score;
-        float worst_score = worst_finished_sequence.m_score;
-        switch (sampling_params.stop_criteria) {
-        case ov::genai::StopCriteria::EARLY:
-            done = true;
-            return;
-        case ov::genai::StopCriteria::HEURISTIC: {
-            float highest_attainable_score = best_sum_logprobs / std::pow(float(cur_len), sampling_params.length_penalty);
-            done = worst_score >= highest_attainable_score;
-            return;
-        }
-        case ov::genai::StopCriteria::NEVER: {
-            size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.max_new_tokens : cur_len;
-            float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
-            done = worst_score >= highest_attainable_score;
-            return;
-        }
-        default:
-            OPENVINO_THROW("Beam search internal error: unkown mode");
-        }
-    }
+    int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
+    void is_done(const ov::genai::GenerationConfig& sampling_params);
 };

 struct SamplerOutput {
@@ -311,121 +96,14 @@ class GroupBeamSearcher {
     explicit GroupBeamSearcher(SequenceGroup::Ptr sequence_group, Tokenizer tokenizer);

     void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output);
-
-    void finalize(SamplerOutput& sampler_output) {
-        for (Group& group : m_groups) {
-            if (!group.done) {
-                for (Beam& beam : group.ongoing) {
-                    uint64_t sequence_id = beam.m_sequence->get_id();
-
-                    int64_t preempted_id = group.finish(beam, m_parameters);
-                    if (preempted_id >= 0) {
-                        // remove preempted one
-                        m_sequence_group->remove_sequence(preempted_id);
-                    }
-
-                    // mark current sequence as finished
-                    beam.m_sequence->set_status(SequenceStatus::FINISHED);
-                    // Setting length since this function is used when sequence generated tokens number reaches max_new_tokens
-                    beam.m_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
-                    // we also need to drop add ongoing / forked sequences from scheduler
-                    sampler_output.m_dropped_sequences.push_back(sequence_id);
-                }
-            }
-        }
-    }
+    void finalize(SamplerOutput& sampler_output);
 };

 class Sampler {
-
-    Logits _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
-        ov::Shape logits_shape = logits.get_shape();
-        size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
-        OPENVINO_ASSERT(batch_idx <= batch_size);
-        size_t batch_offset = batch_idx * seq_len * vocab_size;
-        size_t sequence_offset = (seq_len - 1) * vocab_size;
-        float* logits_data = logits.data<float>() + batch_offset + sequence_offset;
-
-        return Logits{logits_data, vocab_size};
-    }
-
-    Token _greedy_sample(const Logits& logits) const {
-        // For greedy sampling we do not expect sorting or shrinking considered tokens
-        // so we can operate directly on the data buffer
-        float max_value = -std::numeric_limits<float>::infinity();
-        size_t max_index = 0;
-        for (size_t i = 0; i < logits.m_size; ++i) {
-            if (logits.m_data[i] > max_value) {
-                max_value = logits.m_data[i];
-                max_index = i;
-            }
-        }
-
-        // apply log softmax to max value
-        float log_sum = std::log(std::accumulate(
-            logits.m_data, logits.m_data + logits.m_size, 0.0f, [max_value](float accumulated, float to_add) {
-                return accumulated + std::exp(to_add - max_value);
-            }));
-        max_value = -log_sum;
-
-        return Token(max_value, max_index);
-    }
-
-    std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence) {
-        // If top_p or top_k was applied we use sorted vector, if not we go with original buffer.
-        std::vector<float> multinomial_weights;
-        multinomial_weights.reserve(logits.m_size);
-        if (logits.is_vector_initialized())
-            for (auto& logit: logits.m_vector) multinomial_weights.emplace_back(logit.m_log_prob);
-        else
-            multinomial_weights.assign(logits.m_data, logits.m_data + logits.m_size);
-
-        auto dist = std::discrete_distribution<size_t>(multinomial_weights.begin(), multinomial_weights.end()); // equivalent to multinomial with number of trials == 1
-
-        std::vector<Token> out_tokens;
-        for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) {
-            size_t element_to_pick = dist(rng_engine);
-            if (logits.is_vector_initialized())
-                out_tokens.push_back(logits.m_vector[element_to_pick]);
-            else
-                out_tokens.emplace_back(logits.m_data[element_to_pick], element_to_pick);
-        }
-        return out_tokens;
-    }
-
-    std::vector<int64_t> _try_finish_generation(SequenceGroup::Ptr & sequence_group) {
-        auto sampling_params = sequence_group->get_sampling_parameters();
-        std::vector<int64_t> dropped_seq_ids;
-        for (auto& running_sequence : sequence_group->get_running_sequences()) {
-            const auto generated_len = running_sequence->get_generated_len();
-            if (sampling_params.max_new_tokens == generated_len ||
-                is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
-                // stop sequence by max_new_tokens or stop token (eos included)
-                running_sequence->set_status(SequenceStatus::FINISHED);
-
-                if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
-                    running_sequence->set_finish_reason(GenerationFinishReason::STOP);
-                } else if (sampling_params.max_new_tokens == generated_len) {
-                    running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
-                }
-
-                dropped_seq_ids.push_back(running_sequence->get_id());
-                continue;
-            }
-
-            if (!sampling_params.stop_strings.empty()) {
-                int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings);
-                if (num_matched_last_tokens) {
-                    if (!sampling_params.include_stop_str_in_output)
-                        running_sequence->remove_last_tokens(num_matched_last_tokens);
-                    running_sequence->set_status(SequenceStatus::FINISHED);
-                    running_sequence->set_finish_reason(GenerationFinishReason::STOP);
-                    dropped_seq_ids.push_back(running_sequence->get_id());
-                }
-            }
-        }
-        return dropped_seq_ids;
-    }
+    Logits _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1);
+    Token _greedy_sample(const Logits& logits) const;
+    std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence);
+    std::vector<int64_t> _try_finish_generation(SequenceGroup::Ptr & sequence_group);

     // request ID => beam search tracking information
     std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;
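
As a worked illustration of the chunked stop-string check documented above for match_stop_string2: when the stop string does not match token-for-token, the generated tail is re-decoded with a window that grows by 4 tokens per iteration until the decoded text contains the stop string or the whole sequence has been scanned. The sketch below mocks the tokenizer with a per-token string table; MockVocab, the token values, and the stop string are illustrative, not taken from the repository.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for Tokenizer::decode: concatenate per-token strings.
struct MockVocab {
    std::vector<std::string> pieces;
    std::string decode(const std::vector<int64_t>& ids) const {
        std::string out;
        for (int64_t id : ids) out += pieces.at(static_cast<size_t>(id));
        return out;
    }
};

// Same growth pattern as match_stop_string2: start from the stop string's own token count,
// then widen the decoded tail by 4 tokens per iteration.
int match_stop_string_chunked(const MockVocab& tok, const std::vector<int64_t>& generated,
                              const std::string& stop_string, size_t stop_token_count) {
    for (size_t num_tokens = stop_token_count; num_tokens <= generated.size(); num_tokens += 4) {
        std::vector<int64_t> tail(generated.end() - num_tokens, generated.end());
        if (tok.decode(tail).find(stop_string) != std::string::npos)
            return static_cast<int>(num_tokens);
    }
    return 0;
}

int main() {
    MockVocab tok{{"Hel", "lo", " wor", "ld", "!", " END", "..."}};
    std::vector<int64_t> generated{0, 1, 2, 3, 4, 5, 6};   // decodes to "Hello world! END..."
    // The stop string "d! END" spans token boundaries, so only a decoded window can find it.
    std::cout << match_stop_string_chunked(tok, generated, "d! END", 2) << '\n';   // prints 6
}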
