@@ -44,7 +44,10 @@ std::vector<int64_t> kmp_search(const std::vector<int64_t>& haystack, const std:
44
44
return res;
45
45
}
46
46
47
- struct Token {float log_prob; int64_t idx;};
47
+ struct Token {
48
+ float log_prob;
49
+ int64_t idx;
50
+ };
48
51
49
52
std::vector<Token> log_softmax (const ov::Tensor& logits, size_t batch_idx) {
50
53
if (logits.get_shape ().at (0 ) <= batch_idx) {
@@ -55,10 +58,10 @@ std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx) {
55
58
size_t sequence_offset = (logits.get_shape ().at (1 ) - 1 ) * vocab_size;
56
59
const float * beam_logits = logits.data <const float >() + batch_offset + sequence_offset;
57
60
float max_logit = *std::max_element (beam_logits, beam_logits + vocab_size);
58
- float log_sum = std::log (std::accumulate (
59
- beam_logits, beam_logits + vocab_size, 0 .0f , [max_logit](float accumulated, float to_add) {
61
+ float log_sum = std::log (
62
+ std::accumulate ( beam_logits, beam_logits + vocab_size, 0 .0f , [max_logit](float accumulated, float to_add) {
60
63
return accumulated + std::exp (to_add - max_logit);
61
- }));
64
+ }));
62
65
std::vector<Token> tokens;
63
66
tokens.reserve (vocab_size);
64
67
for (size_t idx = 0 ; idx < vocab_size; ++idx) {
@@ -77,7 +80,7 @@ bool greater(const Beam& left, const Beam& right) {
77
80
return left.score > right.score ;
78
81
}
79
82
80
- enum class StopCriteria {early, heuristic, never};
83
+ enum class StopCriteria { early, heuristic, never };
81
84
82
85
struct Parameters {
83
86
std::vector<int64_t > prompt;
@@ -90,14 +93,24 @@ struct Parameters {
90
93
size_t no_repeat_ngram_size = std::numeric_limits<size_t >::max();
91
94
// There's no way to extract special token values from the tokenizer for now
92
95
int64_t eos_token = 2 ;
93
- std::function<bool (const Beam&)> early_finish = [](const Beam&){return false ;};
96
+ std::function<bool (const Beam&)> early_finish = [](const Beam&) {
97
+ return false ;
98
+ };
94
99
};
95
100
96
101
struct Group {
97
- std::vector<Beam> ongoing; // Best beams in front
102
+ std::vector<Beam> ongoing; // Best beams in front
98
103
std::vector<Beam> min_heap; // The worst of the best completed beams is the first
99
104
bool done = false ;
100
- void finish (Beam&& beam, const Parameters& parameters) {
105
+
106
+ // finalize parameter introduced to match huggingface implementation
107
+ void finish (Beam&& beam, const Parameters& parameters, const bool finalize = false ) {
108
+ size_t cur_len = ongoing.front ().tokens .size ();
109
+
110
+ if (!finalize) {
111
+ cur_len += 1 ;
112
+ }
113
+
101
114
beam.score /= std::pow (float (parameters.prompt .size () + beam.tokens .size ()), parameters.length_penalty );
102
115
min_heap.push_back (std::move (beam));
103
116
std::push_heap (min_heap.begin (), min_heap.end (), greater);
@@ -110,30 +123,34 @@ struct Group {
110
123
if (min_heap.size () < parameters.group_size ) {
111
124
return ;
112
125
}
113
- size_t cur_len = parameters. prompt . size () + ongoing.front ().tokens .size ();
126
+ size_t cur_len = ongoing.front ().tokens .size () + 1 ;
114
127
float best_sum_logprobs = ongoing.front ().score ;
115
128
float worst_score = min_heap.front ().score ;
116
129
switch (parameters.stop_criteria ) {
117
- case StopCriteria::early:
118
- done = true ;
119
- return ;
120
- case StopCriteria::heuristic: {
121
- float highest_attainable_score = best_sum_logprobs / std::pow (float (cur_len), parameters.length_penalty );
122
- done = worst_score >= highest_attainable_score;
123
- return ;
124
- }
125
- case StopCriteria::never: {
126
- size_t length = parameters.length_penalty > 0.0 ? parameters.max_new_tokens : cur_len;
127
- float highest_attainable_score = best_sum_logprobs / std::pow (float (length), parameters.length_penalty );
128
- done = worst_score >= highest_attainable_score;
129
- return ;
130
- }
131
- default : throw std::runtime_error (" Never reached" );
130
+ case StopCriteria::early:
131
+ done = true ;
132
+ return ;
133
+ case StopCriteria::heuristic: {
134
+ float highest_attainable_score = best_sum_logprobs / std::pow (float (cur_len), parameters.length_penalty );
135
+ done = worst_score >= highest_attainable_score;
136
+ return ;
137
+ }
138
+ case StopCriteria::never: {
139
+ size_t length = parameters.length_penalty > 0.0 ? parameters.max_new_tokens : cur_len;
140
+ float highest_attainable_score = best_sum_logprobs / std::pow (float (length), parameters.length_penalty );
141
+ done = worst_score >= highest_attainable_score;
142
+ return ;
143
+ }
144
+ default :
145
+ throw std::runtime_error (" Never reached" );
132
146
}
133
147
}
134
148
};
135
149
136
- struct TokenToBeam {int64_t token_idx; int32_t beam_idx;};
150
+ struct TokenToBeam {
151
+ int64_t token_idx;
152
+ int32_t beam_idx;
153
+ };
137
154
138
155
// GroupBeamSearcher processes logits prduced by a language model and accumulates beams using group beam search
139
156
// algorithm. select_next_tokens() returns token ids selected by the algorithm and corresponding beam ids. These values
@@ -173,7 +190,7 @@ struct GroupBeamSearcher {
173
190
continue ;
174
191
}
175
192
std::vector<Beam> candidates;
176
- candidates.reserve (2 * parameters.group_size );
193
+ candidates.reserve (parameters. group_size * 2 * parameters.group_size );
177
194
for (const Beam& beam : group->ongoing ) {
178
195
std::vector<Token> tokens = log_softmax (logits, beam.global_beam_idx );
179
196
for (auto prev_group = groups.cbegin (); prev_group != group; ++prev_group) {
@@ -251,7 +268,7 @@ std::vector<std::vector<Beam>> finalize(GroupBeamSearcher&& group_beam_searcher)
251
268
for (Group& group : group_beam_searcher.groups ) {
252
269
if (!group.done ) {
253
270
for (Beam& beam : group.ongoing ) {
254
- group.finish (std::move (beam), group_beam_searcher.parameters );
271
+ group.finish (std::move (beam), group_beam_searcher.parameters , true );
255
272
}
256
273
}
257
274
finalized.push_back (std::move (group.min_heap ));
0 commit comments