Commit 717f311

Align length_penalty

1 parent d6d4f00 commit 717f311

4 files changed: +146 -28 lines changed

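For context: the commit aligns the sample's beam-search scoring with Hugging Face transformers, which in recent releases normalizes a finished hypothesis by the number of generated tokens (the prompt excluded) raised to length_penalty. A minimal sketch of that normalization, as a reference illustration rather than code from this commit:

    #include <cmath>
    #include <cstddef>

    // Reference normalization, assuming transformers' semantics: divide the
    // summed log-probabilities by generated_length ** length_penalty.
    float normalized_score(float sum_logprobs, size_t generated_length, float length_penalty) {
        return sum_logprobs / std::pow(static_cast<float>(generated_length), length_penalty);
    }

With length_penalty=1.0, the value the new CI checks pass to generate(), this reduces to averaging log-probabilities per generated token.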

.github/workflows/causal_lm_cpp.yml (+33 -1)
@@ -6,7 +6,7 @@ on:
       - llm_bench/python/**
       - text_generation/causal_lm/cpp/*
       - thirdparty/openvino_tokenizers
-      - '!**.md'
+      - "!**.md"
 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
   cancel-in-progress: true
@@ -94,6 +94,38 @@ jobs:
               predictions = predictions[:idx] + predictions[idx + len(ref):]
           "
           echo Hi passed
+
+          timeout 25s ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/ "return 0" > ./pred.txt
+          python -c "
+          import transformers
+          with open('pred.txt', 'r') as file:
+              predictions = file.read()
+          tokenizer = transformers.LlamaTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0')
+          tokenized = tokenizer('return 0', return_tensors='pt')
+          for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False):
+              ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) + '\n'
+              idx = predictions.find(ref)
+              if -1 == idx:
+                  raise RuntimeError(f'Missing "{ref=}" from predictions')
+              predictions = predictions[:idx] + predictions[idx + len(ref):]
+          "
+          echo return 0 passed
+
+          ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/ "你好! 你好嗎?" > ./pred.txt
+          python -c "
+          import transformers
+          with open('pred.txt', 'r') as file:
+              predictions = file.read()
+          tokenizer = transformers.LlamaTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0')
+          tokenized = tokenizer('你好! 你好嗎?', return_tensors='pt')
+          for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False):
+              ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) + '\n'
+              idx = predictions.find(ref)
+              if -1 == idx:
+                  raise RuntimeError(f'Missing "{ref=}" from predictions')
+              predictions = predictions[:idx] + predictions[idx + len(ref):]
+          "
+          echo 你好! 你好嗎? passed
   cpp-beam_search_causal_lm-windows:
     runs-on: windows-latest
     steps:
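Each added check pipes the sample's stdout to pred.txt, re-generates the same prompt with transformers under matching parameters, and requires every reference continuation to occur verbatim in the predictions; matched spans are excised from the buffer, so a beam the reference produces twice must also appear twice in the sample's output. A C++ rendering of that find-and-excise step, as a hypothetical mirror of the Python above:

    #include <stdexcept>
    #include <string>

    // Erase one occurrence of `ref` from `predictions`, failing if absent,
    // like the `predictions.find` / slice-out logic in the workflow's Python.
    void expect_and_erase(std::string& predictions, const std::string& ref) {
        size_t idx = predictions.find(ref);
        if (idx == std::string::npos) {
            throw std::runtime_error("Missing reference beam: " + ref);
        }
        predictions.erase(idx, ref.size());
    }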

.gitignore (+41)
@@ -0,0 +1,41 @@
+# build/artifact dirs
+_*
+[Bb]uild*/
+cmake-build*
+
+# but ensure we don't skip __init__.py and __main__.py
+!__init__.py
+!__main__.py
+
+# developer tools
+*.idea
+.vscode
+.vs/
+.vsconan/
+.DS_Store
+**/tags
+compile_commands.json
+bin/
+.local_vimrc
+.gdb_history
+.vimspector.json
+doc/
+docs/build_documentation/work_dir/
+temp/
+.repo/
+CMakeLists.txt.user
+docs/IE_PLUGIN_DG/html/
+CMakeUserPresets.json
+
+*.project
+*.cproject
+*.pydevproject
+*.settings
+*/gen/
+*.swp
+/config.xml
+
+# Python-specific
+*.?env*
+*.pyc
+__pycache__
.clang-format (+28)
@@ -0,0 +1,28 @@
+BasedOnStyle: Google
+IndentWidth: 4
+UseTab: Never
+ColumnLimit: 120
+
+Language: Cpp
+Standard: Cpp11
+
+AccessModifierOffset: -4
+AlignConsecutiveMacros: true
+AllowAllArgumentsOnNextLine: false
+AllowAllConstructorInitializersOnNextLine: false
+AllowAllParametersOfDeclarationOnNextLine: false
+AllowShortFunctionsOnASingleLine: Empty
+AllowShortIfStatementsOnASingleLine: Never
+AllowShortLambdasOnASingleLine: Empty
+AllowShortLoopsOnASingleLine: false
+AlwaysBreakBeforeMultilineStrings: false
+BinPackArguments: false
+BinPackParameters: false
+CommentPragmas: '^#'
+DerivePointerAlignment: false
+FixNamespaceComments: true
+IndentCaseLabels: false
+IndentPPDirectives: AfterHash
+ForEachMacros:
+  - foreach
+  - FOREACH_CHILD

text_generation/causal_lm/cpp/group_beam_searcher.hpp (+44 -27)
@@ -44,7 +44,10 @@ std::vector<int64_t> kmp_search(const std::vector<int64_t>& haystack, const std:
     return res;
 }
 
-struct Token {float log_prob; int64_t idx;};
+struct Token {
+    float log_prob;
+    int64_t idx;
+};
 
 std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx) {
     if (logits.get_shape().at(0) <= batch_idx) {
@@ -55,10 +58,10 @@ std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx) {
     size_t sequence_offset = (logits.get_shape().at(1) - 1) * vocab_size;
     const float* beam_logits = logits.data<const float>() + batch_offset + sequence_offset;
     float max_logit = *std::max_element(beam_logits, beam_logits + vocab_size);
-    float log_sum = std::log(std::accumulate(
-        beam_logits, beam_logits + vocab_size, 0.0f, [max_logit](float accumulated, float to_add) {
+    float log_sum = std::log(
+        std::accumulate(beam_logits, beam_logits + vocab_size, 0.0f, [max_logit](float accumulated, float to_add) {
             return accumulated + std::exp(to_add - max_logit);
-    }));
+        }));
     std::vector<Token> tokens;
     tokens.reserve(vocab_size);
     for (size_t idx = 0; idx < vocab_size; ++idx) {
@@ -77,7 +80,7 @@ bool greater(const Beam& left, const Beam& right) {
     return left.score > right.score;
 }
 
-enum class StopCriteria {early, heuristic, never};
+enum class StopCriteria { early, heuristic, never };
 
 struct Parameters {
     std::vector<int64_t> prompt;
@@ -90,14 +93,24 @@ struct Parameters {
     size_t no_repeat_ngram_size = std::numeric_limits<size_t>::max();
     // There's no way to extract special token values from the tokenizer for now
     int64_t eos_token = 2;
-    std::function<bool(const Beam&)> early_finish = [](const Beam&){return false;};
+    std::function<bool(const Beam&)> early_finish = [](const Beam&) {
+        return false;
+    };
 };
 
 struct Group {
-    std::vector<Beam> ongoing; // Best beams in front
+    std::vector<Beam> ongoing;  // Best beams in front
     std::vector<Beam> min_heap; // The worst of the best completed beams is the first
     bool done = false;
-    void finish(Beam&& beam, const Parameters& parameters) {
+
+    // finalize parameter introduced to match huggingface implementation
+    void finish(Beam&& beam, const Parameters& parameters, const bool finalize = false) {
+        size_t cur_len = ongoing.front().tokens.size();
+
+        if (!finalize) {
+            cur_len += 1;
+        }
+
         beam.score /= std::pow(float(parameters.prompt.size() + beam.tokens.size()), parameters.length_penalty);
         min_heap.push_back(std::move(beam));
         std::push_heap(min_heap.begin(), min_heap.end(), greater);
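The new finalize flag reflects a distinction inherited from the Hugging Face implementation: a beam finished during the search loop is scored while the current step's token is still being appended, so its effective length is one greater than tokens.size(), whereas beams force-finished in finalize() already hold every generated token. A tiny sketch of that length bookkeeping, with hypothetical names:

    #include <cstddef>

    // Hypothetical helper mirroring the cur_len logic above: `generated` is
    // the number of tokens the leading beam holds when finish() is called.
    size_t effective_length(size_t generated, bool finalize) {
        // During the search the current step's token is not yet counted in
        // tokens.size(), so add one; at finalize() nothing more is appended.
        return finalize ? generated : generated + 1;
    }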
@@ -110,30 +123,34 @@
         if (min_heap.size() < parameters.group_size) {
             return;
         }
-        size_t cur_len = parameters.prompt.size() + ongoing.front().tokens.size();
+        size_t cur_len = ongoing.front().tokens.size() + 1;
         float best_sum_logprobs = ongoing.front().score;
         float worst_score = min_heap.front().score;
         switch (parameters.stop_criteria) {
-            case StopCriteria::early:
-                done = true;
-                return;
-            case StopCriteria::heuristic: {
-                float highest_attainable_score = best_sum_logprobs / std::pow(float(cur_len), parameters.length_penalty);
-                done = worst_score >= highest_attainable_score;
-                return;
-            }
-            case StopCriteria::never: {
-                size_t length = parameters.length_penalty > 0.0 ? parameters.max_new_tokens : cur_len;
-                float highest_attainable_score = best_sum_logprobs / std::pow(float(length), parameters.length_penalty);
-                done = worst_score >= highest_attainable_score;
-                return;
-            }
-            default: throw std::runtime_error("Never reached");
+        case StopCriteria::early:
+            done = true;
+            return;
+        case StopCriteria::heuristic: {
+            float highest_attainable_score = best_sum_logprobs / std::pow(float(cur_len), parameters.length_penalty);
+            done = worst_score >= highest_attainable_score;
+            return;
+        }
+        case StopCriteria::never: {
+            size_t length = parameters.length_penalty > 0.0 ? parameters.max_new_tokens : cur_len;
+            float highest_attainable_score = best_sum_logprobs / std::pow(float(length), parameters.length_penalty);
+            done = worst_score >= highest_attainable_score;
+            return;
+        }
+        default:
+            throw std::runtime_error("Never reached");
         }
     }
 };
 
-struct TokenToBeam {int64_t token_idx; int32_t beam_idx;};
+struct TokenToBeam {
+    int64_t token_idx;
+    int32_t beam_idx;
+};
 
 // GroupBeamSearcher processes logits prduced by a language model and accumulates beams using group beam search
 // algorithm. select_next_tokens() returns token ids selected by the algorithm and corresponding beam ids. These values
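For reference, the three StopCriteria values mirror transformers' early_stopping settings (True, False, and "never"): early marks the group done as soon as it holds group_size finished beams; heuristic stops once the best ongoing sum of log-probabilities, normalized by the current length, can no longer beat the worst kept beam; never uses the true bound, taking max_new_tokens as the length whenever length_penalty > 0, since a positive penalty only raises the normalized score of a (negative) log-probability sum as the sequence grows. A standalone sketch of that bound, under those assumptions:

    #include <cmath>
    #include <cstddef>

    // Hypothetical helper: best score a group could still attain. `exact`
    // corresponds to StopCriteria::never; the others use the current length.
    float attainable_bound(float best_sum_logprobs, size_t cur_len,
                           size_t max_new_tokens, float length_penalty, bool exact) {
        size_t length = (exact && length_penalty > 0.0f) ? max_new_tokens : cur_len;
        return best_sum_logprobs / std::pow(static_cast<float>(length), length_penalty);
    }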
@@ -173,7 +190,7 @@ struct GroupBeamSearcher {
             continue;
         }
         std::vector<Beam> candidates;
-        candidates.reserve(2 * parameters.group_size);
+        candidates.reserve(parameters.group_size * 2 * parameters.group_size);
         for (const Beam& beam : group->ongoing) {
             std::vector<Token> tokens = log_softmax(logits, beam.global_beam_idx);
             for (auto prev_group = groups.cbegin(); prev_group != group; ++prev_group) {
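The old reservation of 2 * parameters.group_size covered a single beam's worth of candidates; since the loop above collects extensions from every ongoing beam in the group, the pool can grow to roughly group_size times that, which the new capacity accommodates without mid-loop reallocation. The arithmetic, in a hypothetical standalone form:

    #include <cstddef>

    // Hypothetical illustration: group_size ongoing beams, each proposing up
    // to 2 * group_size continuations.
    size_t candidate_capacity(size_t group_size) {
        return group_size * 2 * group_size;  // e.g. group_size = 5 -> 50 slots
    }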
@@ -251,7 +268,7 @@ std::vector<std::vector<Beam>> finalize(GroupBeamSearcher&& group_beam_searcher)
     for (Group& group : group_beam_searcher.groups) {
         if (!group.done) {
             for (Beam& beam : group.ongoing) {
-                group.finish(std::move(beam), group_beam_searcher.parameters);
+                group.finish(std::move(beam), group_beam_searcher.parameters, true);
             }
         }
         finalized.push_back(std::move(group.min_heap));
