namespace ov::genai {

// Modified Knuth–Morris–Pratt algorithm which returns the tokens following every needle occurrence in the haystack
- inline std::vector<int64_t> kmp_search(const std::vector<int64_t>& haystack, const std::vector<int64_t>& needle) {
-     if (needle.empty()) { // no_repeat_ngram_size == 1, ban every token
-         return {haystack.begin(), haystack.end()};
-     }
-     std::vector<int> partial_match_table(needle.size() + 1, -1);
-     int cnd = 0;
-     for (size_t pos = 1; pos < needle.size(); ++pos) {
-         if (needle.at(pos) == needle.at(size_t(cnd))) {
-             partial_match_table.at(pos) = partial_match_table.at(size_t(cnd));
-         } else {
-             partial_match_table.at(pos) = cnd;
-             while (cnd >= 0 && needle.at(pos) != needle.at(size_t(cnd))) {
-                 cnd = partial_match_table.at(size_t(cnd));
-             }
-         }
-         ++cnd;
-     }
-     partial_match_table.back() = cnd;
-     std::vector<int64_t> res;
-     size_t haystack_id = 0;
-     int needle_id = 0;
-     while (haystack_id < haystack.size() - 1) {
-         if (needle.at(size_t(needle_id)) == haystack.at(haystack_id)) {
-             ++haystack_id;
-             ++needle_id;
-             if (needle_id == int(needle.size())) {
-                 res.push_back(haystack.at(haystack_id));
-                 needle_id = partial_match_table.at(size_t(needle_id));
-             }
-         } else {
-             needle_id = partial_match_table.at(size_t(needle_id));
-             if (needle_id < 0) {
-                 ++haystack_id;
-                 ++needle_id;
-             }
-         }
-     }
-     return res;
- }
+ std::vector<int64_t> kmp_search(const std::vector<int64_t>& haystack, const std::vector<int64_t>& needle);
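For intuition, here is a minimal standalone sketch (not part of the patch; it assumes this header is included and the library linked) of the contract: the function returns the token that immediately follows each complete needle match, which is exactly the set of candidates a no-repeat-ngram sampler needs to ban next.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    std::vector<int64_t> haystack{1, 2, 3, 1, 2, 4};
    std::vector<int64_t> needle{1, 2};
    // the needle matches at positions 0 and 3, so the followers are 3 and 4
    for (int64_t token : ov::genai::kmp_search(haystack, needle))
        std::cout << token << ' ';  // prints: 3 4
    std::cout << '\n';
}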

- inline std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx) {
-     ov::Shape shape = logits.get_shape();
-     OPENVINO_ASSERT(shape.size() == 3);
-     size_t batch = shape[0], seq_len = shape[1], vocab_size = shape[2];
-     OPENVINO_ASSERT(batch_idx < batch, "Logits batch size doesn't match the number of beams");
-
-     size_t batch_offset = batch_idx * seq_len * vocab_size, sequence_offset = (seq_len - 1) * vocab_size;
-     const float* beam_logits = logits.data<const float>() + batch_offset + sequence_offset;
-     float max_logit = *std::max_element(beam_logits, beam_logits + vocab_size);
-     float log_sum = std::log(std::accumulate(
-         beam_logits, beam_logits + vocab_size, 0.0f, [max_logit](float accumulated, float to_add) {
-             return accumulated + std::exp(to_add - max_logit);
-         }));
-
-     std::vector<Token> tokens;
-     tokens.reserve(vocab_size);
-     for (size_t idx = 0; idx < vocab_size; ++idx)
-         tokens.push_back({beam_logits[idx] - max_logit - log_sum, int64_t(idx)});
-
-     return tokens;
- }
+ std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx);
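The removed body is the standard numerically stable log-softmax over the last position's logits: log p_i = x_i - max(x) - log(sum_j exp(x_j - max(x))). A self-contained reference sketch on plain floats (assumed values, no OpenVINO types):

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Subtracting max(x) before exponentiating keeps std::exp from overflowing.
std::vector<float> log_softmax_ref(const std::vector<float>& x) {
    float max_logit = *std::max_element(x.begin(), x.end());
    float log_sum = std::log(std::accumulate(
        x.begin(), x.end(), 0.0f,
        [max_logit](float acc, float v) { return acc + std::exp(v - max_logit); }));
    std::vector<float> out;
    out.reserve(x.size());
    for (float v : x)
        out.push_back(v - max_logit - log_sum);
    return out;  // {1, 2, 3} -> approximately {-2.408, -1.408, -0.408}
}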

- inline std::vector<int64_t>
- wrap_tokens(const std::vector<int64_t>& tokens, const std::vector<int64_t>& prefix_tokens, const std::vector<int64_t>& suffix_tokens) {
-     std::vector<int64_t> all_tokens = prefix_tokens;
-     all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());
-     all_tokens.insert(all_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-     return all_tokens;
- }
+ std::vector<int64_t> wrap_tokens(const std::vector<int64_t>& tokens, const std::vector<int64_t>& prefix_tokens, const std::vector<int64_t>& suffix_tokens);

- inline std::string clean_wrapped_text(const std::string& wrapped_text, const std::string& prefix, const std::string& suffix) {
-     auto prefix_pos = wrapped_text.find(prefix);
-     OPENVINO_ASSERT(prefix_pos != std::string::npos);
-     auto suffix_pos = wrapped_text.rfind(suffix);
-     OPENVINO_ASSERT(suffix_pos != std::string::npos);
-     auto clean_text_start = prefix_pos + prefix.size();
-     auto clean_text_length = suffix_pos - clean_text_start;
-     std::string clean_text = wrapped_text.substr(clean_text_start, clean_text_length);
-     return clean_text;
- }
+ std::string clean_wrapped_text(const std::string& wrapped_text, const std::string& prefix, const std::string& suffix);
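wrap_tokens and clean_wrapped_text work as a pair: decode prefix + tokens + suffix, then cut the prefix and suffix text back out, so a token's leading whitespace survives detokenization. A sketch of the round trip (assumes a loaded ov::genai::Tokenizer named tokenizer; some_token_id is a hypothetical token id):

auto prefix_ids = tokenizer.encode("a").input_ids;
std::vector<int64_t> prefix_tokens(prefix_ids.data<int64_t>(),
                                   prefix_ids.data<int64_t>() + prefix_ids.get_size());
auto suffix_ids = tokenizer.encode("b").input_ids;
std::vector<int64_t> suffix_tokens(suffix_ids.data<int64_t>(),
                                   suffix_ids.data<int64_t>() + suffix_ids.get_size());

// If some_token_id decodes to " world", decoding it bare may drop the leading
// space, but decoding the wrapped sequence yields "a worldb", and
// clean_wrapped_text("a worldb", "a", "b") recovers " world" intact.
std::vector<int64_t> wrapped = wrap_tokens({some_token_id}, prefix_tokens, suffix_tokens);
std::string token_text = clean_wrapped_text(tokenizer.decode(wrapped), "a", "b");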

// Return the number of last tokens that match one of the stop_strings. If there's no match, 0 is returned.
- inline int
- match_stop_string(Tokenizer& tokenizer, const TokenIds& generated_tokens, const std::set<std::string>& stop_strings) {
-     /*
-     To catch a stop_string hit we run the comparison character-wise, which covers cases where the stop string
-     overlaps with part of another token on either side or is just a part of a single token.
-     For every stop_string we iterate over the generated tokens starting from the last one and going backwards.
-     Every token is wrapped with prefix tokens to ensure the tokenizer doesn't remove the prefix whitespace of the actual token.
-     After that all tokens are decoded and the prefix is removed from the decoded text, so we end up with the decoded token.
-     Its characters are compared to the stop_string character at current_position
-     (the position of a character in the stop_string, counting from the last one) - at the beginning the position is 0.
-     When characters match we increase current_position and check if we already have a full match; if not, we continue.
-     If we have already matched some characters (current_position > 0) and the next character does not match
-     before we reach the full match, we reset current_position to 0.
-     */
-     std::string prefix = "a";
-     auto prefix_ov = tokenizer.encode(prefix).input_ids;
-     std::vector<int64_t> prefix_tokens(prefix_ov.data<int64_t>(), prefix_ov.data<int64_t>() + prefix_ov.get_size());
-     std::string suffix = "b";
-     auto suffix_ov = tokenizer.encode(suffix).input_ids;
-     std::vector<int64_t> suffix_tokens(suffix_ov.data<int64_t>(), suffix_ov.data<int64_t>() + suffix_ov.get_size());
-
-     // Since whitespace can be added at the beginning of the suffix, we also try to capture that behavior here
-     // and get the suffix string that will actually be part of the decoded string, so we can remove it correctly
-     auto wrapped_suffix_tokens = suffix_tokens;
-     wrapped_suffix_tokens.insert(wrapped_suffix_tokens.begin(), prefix_tokens.begin(), prefix_tokens.end());
-     std::string wrapped_suffix = tokenizer.decode(wrapped_suffix_tokens);
-     auto wrapper_pos = wrapped_suffix.find(prefix);
-     suffix = wrapped_suffix.substr(wrapper_pos + prefix.size());
-
-     for (auto stop_string : stop_strings) {
-         int current_position = 0;
-         int num_matched_tokens = 0;
-         // Get a reverse iterator to check tokens starting from the last generated one and going backwards
-         auto generated_tokens_rit = generated_tokens.rbegin();
-         std::vector<int64_t> tokens_buffer;
-         while (generated_tokens_rit != generated_tokens.rend()) {
-             num_matched_tokens++;
-             tokens_buffer.insert(tokens_buffer.begin(), *generated_tokens_rit);
-
-             std::vector<int64_t> wrapped_tokens = wrap_tokens(tokens_buffer, prefix_tokens, suffix_tokens);
-             std::string wrapped_text = tokenizer.decode(wrapped_tokens);
-             std::string clean_text = clean_wrapped_text(wrapped_text, prefix, suffix);
-
-             if (clean_text == "" || (clean_text.size() >= 3 && (clean_text.compare(clean_text.size() - 3, 3, "�") == 0))) {
-                 generated_tokens_rit++;
-                 continue;
-             } else {
-                 tokens_buffer.clear();
-             }
-             // Check clean_text characters starting from the last one
-             for (auto clean_text_rit = clean_text.rbegin(); clean_text_rit != clean_text.rend(); clean_text_rit++) {
-                 // On a character match, increment current_position for the next comparisons
-                 if (*clean_text_rit == *(stop_string.rbegin() + current_position)) {
-                     current_position++;
-                     // If this is the last character of the stop_string, we have a match
-                     if ((stop_string.rbegin() + current_position) == stop_string.rend()) {
-                         return num_matched_tokens;
-                     }
-                 } else if (current_position) {
-                     // Already found matching characters, but the last one didn't match, so reset current_position
-                     current_position = 0;
-                     // The search for a match will start over from this character, so decrement the iterator
-                     clean_text_rit--;
-                 }
-             }
-             generated_tokens_rit++;
-         }
-     }
-     return 0;
- }
+ int match_stop_string(Tokenizer& tokenizer, const TokenIds& generated_tokens, const std::set<std::string>& stop_strings);
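A usage sketch (hypothetical tokenizer and generated_ids, mirroring how _try_finish_generation calls it below): if the generated text ends in "...The end." and stop_strings contains "end.", the return value is the number of trailing tokens that cover the match, which the caller can strip with remove_last_tokens() when include_stop_str_in_output is false.

std::set<std::string> stop_strings{"end."};
int n = match_stop_string(tokenizer, generated_ids, stop_strings);
if (n > 0) {
    // the last n tokens contain a stop string; mark the sequence FINISHED
}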

// Return the number of last tokens that match one of the stop_strings. If there's no match, 0 is returned.
// The number of tokens might not be exact: if there's no direct token match, we decode the generated tokens incrementally,
// expanding the decoding scope by 4 more tokens with each iteration until we have checked all tokens.
- inline int match_stop_string2(Tokenizer& tokenizer, const TokenIds& generated_tokens, const std::set<std::string>& stop_strings) {
-     for (auto stop_string : stop_strings) {
-         auto stop_tokens_ov = tokenizer.encode(stop_string).input_ids;
-         size_t num_tokens = stop_tokens_ov.get_size();
-         if (num_tokens > generated_tokens.size())
-             continue;
-
-         // Check for a direct token match
-         std::vector<int64_t> stop_tokens(stop_tokens_ov.data<int64_t>(), stop_tokens_ov.data<int64_t>() + num_tokens);
-         std::vector<int64_t> last_generated_tokens(generated_tokens.end() - num_tokens, generated_tokens.end());
-         if (stop_tokens == last_generated_tokens)
-             return num_tokens;
-
-         // Continue checking in chunks of 4 tokens
-         num_tokens += 4;
-         while (num_tokens <= generated_tokens.size()) {
-             std::vector<int64_t> last_generated_tokens(generated_tokens.end() - num_tokens, generated_tokens.end());
-             std::string decoded_last_tokens = tokenizer.decode(last_generated_tokens);
-             if (decoded_last_tokens.find(stop_string) != std::string::npos) {
-                 return num_tokens;
-             }
-             num_tokens += 4;
-         }
-     }
-
-     return 0;
- }
+ int match_stop_string2(Tokenizer& tokenizer, const TokenIds& generated_tokens, const std::set<std::string>& stop_strings);

// Handle stop_token_ids
inline bool is_stop_token_id_hit(int64_t generated_token, const std::set<int64_t>& stop_token_ids) {
@@ -243,55 +75,8 @@ struct Group {
    std::vector<Beam> min_heap;  // The worst of the best completed beams is the first
    bool done = false;

-     int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params) {
-         int64_t preempted_sequence_id = -1;
-         float generated_len = beam.get_generated_len() + (is_stop_token_id_hit(beam.m_token_id, sampling_params.stop_token_ids) ? 1 : 0);  // HF counts the EOS token in the generation length
-         beam.m_score /= std::pow(generated_len, sampling_params.length_penalty);
-
-         min_heap.push_back(beam);
-         std::push_heap(min_heap.begin(), min_heap.end(), greater);
-         assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
-             "number of beams should be divisible by number of groups");
-         size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
-         if (min_heap.size() > group_size) {
-             std::pop_heap(min_heap.begin(), min_heap.end(), greater);
-             preempted_sequence_id = min_heap.back().m_sequence->get_id();
-             min_heap.pop_back();
-         }
-
-         return preempted_sequence_id;
-     }
-
-     void is_done(const ov::genai::GenerationConfig& sampling_params) {
-         assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
-             "number of beams should be divisible by number of groups");
-         size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
-         if (min_heap.size() < group_size)
-             return;
-
-         const Beam& best_running_sequence = ongoing.front(), &worst_finished_sequence = min_heap.front();
-         size_t cur_len = best_running_sequence.m_sequence->get_generated_len();
-         float best_sum_logprobs = best_running_sequence.m_score;
-         float worst_score = worst_finished_sequence.m_score;
-         switch (sampling_params.stop_criteria) {
-         case ov::genai::StopCriteria::EARLY:
-             done = true;
-             return;
-         case ov::genai::StopCriteria::HEURISTIC: {
-             float highest_attainable_score = best_sum_logprobs / std::pow(float(cur_len), sampling_params.length_penalty);
-             done = worst_score >= highest_attainable_score;
-             return;
-         }
-         case ov::genai::StopCriteria::NEVER: {
-             size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.max_new_tokens : cur_len;
-             float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
-             done = worst_score >= highest_attainable_score;
-             return;
-         }
-         default:
-             OPENVINO_THROW("Beam search internal error: unknown mode");
-         }
-     }
+     int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
+     void is_done(const ov::genai::GenerationConfig& sampling_params);
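Both branches above normalize a beam's cumulative log-probability the same way the Hugging Face beam scorer does; a sketch of the scoring rule, for reference:

#include <cmath>

// score(beam) = sum_logprobs / len^length_penalty.
// EARLY stops once group_size beams have finished; HEURISTIC stops when even the
// best running beam, scored at its current length, cannot beat the worst finished
// one; NEVER evaluates the bound at max_new_tokens when length_penalty > 0, since
// a longer continuation could still raise the score.
float beam_score(float sum_logprobs, size_t len, float length_penalty) {
    return sum_logprobs / std::pow(static_cast<float>(len), length_penalty);
}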
};

struct SamplerOutput {
@@ -311,121 +96,14 @@ class GroupBeamSearcher {
    explicit GroupBeamSearcher(SequenceGroup::Ptr sequence_group, Tokenizer tokenizer);

    void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output);
-
-     void finalize(SamplerOutput& sampler_output) {
-         for (Group& group : m_groups) {
-             if (!group.done) {
-                 for (Beam& beam : group.ongoing) {
-                     uint64_t sequence_id = beam.m_sequence->get_id();
-
-                     int64_t preempted_id = group.finish(beam, m_parameters);
-                     if (preempted_id >= 0) {
-                         // remove the preempted one
-                         m_sequence_group->remove_sequence(preempted_id);
-                     }
-
-                     // mark the current sequence as finished
-                     beam.m_sequence->set_status(SequenceStatus::FINISHED);
-                     // Set the finish reason to LENGTH, since this function is used when the number of generated tokens reaches max_new_tokens
-                     beam.m_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
-                     // we also need to drop ongoing / forked sequences from the scheduler
-                     sampler_output.m_dropped_sequences.push_back(sequence_id);
-                 }
-             }
-         }
-     }
+     void finalize(SamplerOutput& sampler_output);
};

class Sampler {
-
-     Logits _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
-         ov::Shape logits_shape = logits.get_shape();
-         size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
-         OPENVINO_ASSERT(batch_idx <= batch_size);
-         size_t batch_offset = batch_idx * seq_len * vocab_size;
-         size_t sequence_offset = (seq_len - 1) * vocab_size;
-         float* logits_data = logits.data<float>() + batch_offset + sequence_offset;
-
-         return Logits{logits_data, vocab_size};
-     }
-
-     Token _greedy_sample(const Logits& logits) const {
-         // For greedy sampling we do not expect sorting or shrinking of the considered tokens,
-         // so we can operate directly on the data buffer
-         float max_value = -std::numeric_limits<float>::infinity();
-         size_t max_index = 0;
-         for (size_t i = 0; i < logits.m_size; ++i) {
-             if (logits.m_data[i] > max_value) {
-                 max_value = logits.m_data[i];
-                 max_index = i;
-             }
-         }
-
-         // apply log softmax to the max value
-         float log_sum = std::log(std::accumulate(
-             logits.m_data, logits.m_data + logits.m_size, 0.0f, [max_value](float accumulated, float to_add) {
-                 return accumulated + std::exp(to_add - max_value);
-             }));
-         max_value = -log_sum;
-
-         return Token(max_value, max_index);
-     }
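A step worth spelling out: for the argmax token the stable log-softmax collapses, since log p(argmax) = x_max - x_max - log(sum_j exp(x_j - x_max)) = -log_sum, which is why the code can simply reuse max_value to carry -log_sum as the sampled token's log-probability.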
-
-     std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence) {
-         // If top_p or top_k was applied we use the sorted vector, if not we go with the original buffer.
-         std::vector<float> multinomial_weights;
-         multinomial_weights.reserve(logits.m_size);
-         if (logits.is_vector_initialized())
-             for (auto& logit : logits.m_vector) multinomial_weights.emplace_back(logit.m_log_prob);
-         else
-             multinomial_weights.assign(logits.m_data, logits.m_data + logits.m_size);
-
-         auto dist = std::discrete_distribution<size_t>(multinomial_weights.begin(), multinomial_weights.end());  // equivalent to multinomial with number of trials == 1
-
-         std::vector<Token> out_tokens;
-         for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) {
-             size_t element_to_pick = dist(rng_engine);
-             if (logits.is_vector_initialized())
-                 out_tokens.push_back(logits.m_vector[element_to_pick]);
-             else
-                 out_tokens.emplace_back(logits.m_data[element_to_pick], element_to_pick);
-         }
-         return out_tokens;
-     }
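std::discrete_distribution normalizes whatever non-negative weights it receives, so each draw is a single multinomial trial; this implies the m_log_prob fields are expected to hold non-negative probabilities by this point, not raw log-probs. A minimal standalone sketch of the behavior:

#include <cstdio>
#include <random>
#include <vector>

int main() {
    std::mt19937_64 rng{42};
    std::vector<float> weights{0.1f, 0.2f, 0.7f};  // already-softmaxed probabilities
    std::discrete_distribution<size_t> dist(weights.begin(), weights.end());
    int counts[3] = {0, 0, 0};
    for (int i = 0; i < 10000; ++i)
        ++counts[dist(rng)];
    std::printf("%d %d %d\n", counts[0], counts[1], counts[2]);  // roughly 1000 2000 7000
}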
-
-     std::vector<int64_t> _try_finish_generation(SequenceGroup::Ptr& sequence_group) {
-         auto sampling_params = sequence_group->get_sampling_parameters();
-         std::vector<int64_t> dropped_seq_ids;
-         for (auto& running_sequence : sequence_group->get_running_sequences()) {
-             const auto generated_len = running_sequence->get_generated_len();
-             if (sampling_params.max_new_tokens == generated_len ||
-                 (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos)) {
-                 // stop the sequence on max_new_tokens or on a stop token (EOS included)
-                 running_sequence->set_status(SequenceStatus::FINISHED);
-
-                 if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
-                     running_sequence->set_finish_reason(GenerationFinishReason::STOP);
-                 } else if (sampling_params.max_new_tokens == generated_len) {
-                     running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
-                 }
-
-                 dropped_seq_ids.push_back(running_sequence->get_id());
-                 continue;
-             }
-
-             if (!sampling_params.stop_strings.empty()) {
-                 int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings);
-                 if (num_matched_last_tokens) {
-                     if (!sampling_params.include_stop_str_in_output)
-                         running_sequence->remove_last_tokens(num_matched_last_tokens);
-                     running_sequence->set_status(SequenceStatus::FINISHED);
-                     running_sequence->set_finish_reason(GenerationFinishReason::STOP);
-                     dropped_seq_ids.push_back(running_sequence->get_id());
-                 }
-             }
-         }
-         return dropped_seq_ids;
-     }
+     Logits _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1);
+     Token _greedy_sample(const Logits& logits) const;
+     std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence);
+     std::vector<int64_t> _try_finish_generation(SequenceGroup::Ptr& sequence_group);

    // request ID => beam search tracking information
    std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;