fix: decode ending tokens one by one to handle unfinished tokens (#229)
guocuimi authored Jun 6, 2024
1 parent ab061e8 commit 917c416
Showing 3 changed files with 27 additions and 8 deletions.
26 changes: 21 additions & 5 deletions src/request/sequence.cpp
@@ -24,10 +24,10 @@ Sequence::Sequence(const std::string_view& prompt,
     : id_(next_id_.fetch_add(1)),
       last_token_time_(created_time),
       options_(option),
-      decoder_(prompt,
-               prompt_token_ids.size(),
-               option.echo,
-               option.skip_special_tokens),
+      incremental_decoder_(prompt,
+                           prompt_token_ids.size(),
+                           option.echo,
+                           option.skip_special_tokens),
       num_kv_cache_tokens_(static_cast<size_t>(EngineType::COUNT), 0) {
   CHECK(!prompt_token_ids.empty()) << "empty prompt token ids";
   CHECK_GT(capacity, prompt_token_ids.size()) << "capacity too small";
@@ -128,7 +128,23 @@ size_t Sequence::validate_tokens(const Slice<int64_t>& accpeted_token_ids) {
 // decode the sequence to get delta text using the tokenizer
 std::string Sequence::decode_delta_text(const Slice<int32_t>& token_ids,
                                         const Tokenizer& tokenizer) {
-  return decoder_.decode(token_ids, tokenizer);
+  return incremental_decoder_.decode(token_ids, tokenizer);
 }
 
+// decode the sequence to get text using the tokenizer
+std::string Sequence::decode_text(const Tokenizer& tokenizer) {
+  const auto ids = token_ids();
+  // leave 6 tokens for potential unfinished byte sequence from byte fallback
+  // tokenization
+  size_t start_idx = std::max(num_prompt_tokens_ + 1, ids.size() - 7);
+  std::stringstream ss;
+  // output leading tokens first
+  ss << incremental_decoder_.decode(ids.slice(0, start_idx), tokenizer);
+  // then decode one by one to avoid potential unfinished bytes
+  for (size_t i = start_idx; i < ids.size(); ++i) {
+    ss << incremental_decoder_.decode(ids.slice(0, i + 1), tokenizer);
+  }
+  return ss.str();
+}
+
 void Sequence::append_blocks(const std::vector<Block>& new_blocks) {
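Why decoding the ending tokens one by one fixes the bug: with byte-fallback tokenization, a single token can carry only a fragment of a multi-byte UTF-8 character, so decoding a slice that stops mid-character yields bytes that cannot be rendered. Re-decoding after each trailing token lets the incremental decoder emit text only once it reaches a character boundary. To make the windowing concrete: for a hypothetical sequence with 4 prompt tokens and 20 tokens total, start_idx = std::max(4 + 1, 20 - 7) = 13, so tokens [0, 13) are decoded in one pass and the last 7 tokens are decoded one at a time.

The sketch below illustrates the hold-back idea in isolation; the toy vocabulary and the ends_complete helper are hypothetical stand-ins for illustration, not the repository's IncrementalDecoder API.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Toy byte-fallback vocabulary: tokens 1 and 2 each hold one byte of
// U+00E9 ("é", UTF-8 bytes 0xC3 0xA9); neither byte is printable alone.
static const std::vector<std::string> kVocab = {"hi ", "\xC3", "\xA9", "!"};

// True iff `s` ends on a complete UTF-8 sequence, i.e. flushing it now
// would not cut a multi-byte character in half.
static bool ends_complete(const std::string& s) {
  std::size_t i = s.size();
  std::size_t cont = 0;  // trailing continuation bytes (10xxxxxx)
  while (i > 0 && (static_cast<unsigned char>(s[i - 1]) & 0xC0) == 0x80) {
    --i;
    ++cont;
  }
  if (i == 0) return cont == 0;  // empty, or nothing but continuation bytes
  const unsigned char lead = static_cast<unsigned char>(s[i - 1]);
  if (lead < 0x80) return cont == 0;            // ASCII byte
  if ((lead & 0xE0) == 0xC0) return cont == 1;  // 2-byte sequence
  if ((lead & 0xF0) == 0xE0) return cont == 2;  // 3-byte sequence
  if ((lead & 0xF8) == 0xF0) return cont == 3;  // 4-byte sequence
  return false;                                 // invalid lead byte
}

int main() {
  const std::vector<int> ids = {0, 1, 2, 3};  // decodes to "hi é!"
  std::string emitted;  // text already surfaced to the caller
  std::string buffer;   // bytes held back until a character completes
  for (int id : ids) {
    buffer += kVocab[id];
    if (ends_complete(buffer)) {  // safe to flush: no dangling bytes
      emitted += buffer;
      buffer.clear();
    }
    std::cout << "after token " << id << ": \"" << emitted << "\"\n";
  }
  return 0;
}

Running it prints "hi " unchanged after token 1 (the lone 0xC3 byte stays buffered) and "hi é" only after token 2 completes the character.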
7 changes: 5 additions & 2 deletions src/request/sequence.h
@@ -161,8 +161,11 @@ class Sequence final {
   std::string decode_delta_text(const Slice<int32_t>& token_ids,
                                 const Tokenizer& tokenizer);
 
+  // decode the full sequence to get text using the tokenizer
+  std::string decode_text(const Tokenizer& tokenizer);
+
   // get the offset of output tokens
-  size_t output_offset() const { return decoder_.output_offset(); }
+  size_t output_offset() const { return incremental_decoder_.output_offset(); }
 
   // check finish status, use cached value if not invalidated
   bool is_finished() const;
@@ -215,7 +218,7 @@ class Sequence final {
   Options options_;
 
   // incremental decoder to decode the tokens
-  IncrementalDecoder decoder_;
+  IncrementalDecoder incremental_decoder_;
 
   // token ids generated for the sequence
   std::vector<int32_t> token_ids_;
2 changes: 1 addition & 1 deletion src/scheduler/response_handler.cpp
@@ -74,7 +74,7 @@ void ResponseHandler::on_request_finish(std::unique_ptr<Request> request) {
       const auto finish_reason = seq.finish_reason();
       // generate the final output
       AUTO_COUNTER(non_stream_decode_latency_seconds);
-      auto output = seq.decode_delta_text(seq.token_ids(), *tokenizer);
+      auto output = seq.decode_text(*tokenizer);
       outputs.push_back({i, std::move(output), to_string(finish_reason)});
     }
   }
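Note that on_request_finish previously pushed the sequence's entire token id list through the streaming decode_delta_text path in a single call, so an output ending in an unfinished byte sequence could surface mangled bytes in the final, non-streaming response; decode_text instead decodes the ending tokens one by one, as described above.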
