From 917c41667b0abbd710efb945474e78fff65d6b8f Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Thu, 6 Jun 2024 00:30:11 -0700
Subject: [PATCH] fix: decode ending tokens one by one to handle unfinished
 tokens (#229)

---
 src/request/sequence.cpp           | 26 +++++++++++++++++++++-----
 src/request/sequence.h             |  7 +++++--
 src/scheduler/response_handler.cpp |  2 +-
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/src/request/sequence.cpp b/src/request/sequence.cpp
index 3dda35bd..28a187f1 100644
--- a/src/request/sequence.cpp
+++ b/src/request/sequence.cpp
@@ -24,10 +24,10 @@ Sequence::Sequence(const std::string_view& prompt,
     : id_(next_id_.fetch_add(1)),
       last_token_time_(created_time),
       options_(option),
-      decoder_(prompt,
-               prompt_token_ids.size(),
-               option.echo,
-               option.skip_special_tokens),
+      incremental_decoder_(prompt,
+                           prompt_token_ids.size(),
+                           option.echo,
+                           option.skip_special_tokens),
       num_kv_cache_tokens_(static_cast<size_t>(EngineType::COUNT), 0) {
   CHECK(!prompt_token_ids.empty()) << "empty prompt token ids";
   CHECK_GT(capacity, prompt_token_ids.size()) << "capacity too small";
@@ -128,7 +128,23 @@ size_t Sequence::validate_tokens(const Slice<int64_t>& accpeted_token_ids) {
 // decode the sequence to get delta text using the tokenizer
 std::string Sequence::decode_delta_text(const Slice<int32_t>& token_ids,
                                         const Tokenizer& tokenizer) {
-  return decoder_.decode(token_ids, tokenizer);
+  return incremental_decoder_.decode(token_ids, tokenizer);
+}
+
+// decode the sequence to get text using the tokenizer
+std::string Sequence::decode_text(const Tokenizer& tokenizer) {
+  const auto ids = token_ids();
+  // leave 6 tokens for potential unfinished byte sequence from byte fallback
+  // tokenization
+  size_t start_idx = std::max(num_prompt_tokens_ + 1, ids.size() - 7);
+  std::stringstream ss;
+  // output leading tokens first
+  ss << incremental_decoder_.decode(ids.slice(0, start_idx), tokenizer);
+  // then decode one by one to avoid potential unfinished bytes
+  for (size_t i = start_idx; i < ids.size(); ++i) {
+    ss << incremental_decoder_.decode(ids.slice(0, i + 1), tokenizer);
+  }
+  return ss.str();
 }
 
 void Sequence::append_blocks(const std::vector<Block>& new_blocks) {
diff --git a/src/request/sequence.h b/src/request/sequence.h
index 545c264a..3ccc76b0 100644
--- a/src/request/sequence.h
+++ b/src/request/sequence.h
@@ -161,8 +161,11 @@ class Sequence final {
   std::string decode_delta_text(const Slice<int32_t>& token_ids,
                                 const Tokenizer& tokenizer);
 
+  // decode the full sequence to get text using the tokenizer
+  std::string decode_text(const Tokenizer& tokenizer);
+
   // get the offset of output tokens
-  size_t output_offset() const { return decoder_.output_offset(); }
+  size_t output_offset() const { return incremental_decoder_.output_offset(); }
 
   // check finish status, use cached value if not invalidated
   bool is_finished() const;
@@ -215,7 +218,7 @@ class Sequence final {
   Options options_;
 
   // incremental decoder to decode the tokens
-  IncrementalDecoder decoder_;
+  IncrementalDecoder incremental_decoder_;
 
   // token ids generated for the sequence
   std::vector<int32_t> token_ids_;
diff --git a/src/scheduler/response_handler.cpp b/src/scheduler/response_handler.cpp
index ba3afd6c..cb43c278 100644
--- a/src/scheduler/response_handler.cpp
+++ b/src/scheduler/response_handler.cpp
@@ -74,7 +74,7 @@ void ResponseHandler::on_request_finish(std::unique_ptr<Request> request) {
     const auto finish_reason = seq.finish_reason();
     // generate the final output
     AUTO_COUNTER(non_stream_decode_latency_seconds);
-    auto output = seq.decode_delta_text(seq.token_ids(), *tokenizer);
+    auto output = seq.decode_text(*tokenizer);
     outputs.push_back({i, std::move(output), to_string(finish_reason)});
   }
 }
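Editor's note: the sketch below is a minimal, self-contained illustration of the technique the patch applies in decode_text() — bulk-decode the leading tokens, then feed the last few tokens one at a time so that a character split across tokens by byte-fallback tokenization is finalized as soon as its completing token arrives. ToyIncrementalDecoder, its token table, and the incomplete_utf8_suffix() helper are invented for this example; they are a stand-in for, not the real interface of, ScaleLLM's IncrementalDecoder and Tokenizer.

#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Toy vocabulary: tokens 1 and 2 are the two bytes of "é" (U+00E9), the way
// a byte-fallback tokenizer may split a multi-byte character across tokens.
static const std::vector<std::string> kTokenBytes = {
    "Caf", "\xC3", "\xA9", " au lait"};

// Number of trailing bytes of `s` that form an incomplete UTF-8 sequence.
static size_t incomplete_utf8_suffix(const std::string& s) {
  size_t i = s.size();
  while (i > 0 && (static_cast<unsigned char>(s[i - 1]) & 0xC0) == 0x80) {
    --i;  // skip continuation bytes (10xxxxxx)
  }
  if (i == 0) return 0;
  const unsigned char lead = static_cast<unsigned char>(s[i - 1]);
  const size_t need =
      lead >= 0xF0 ? 4 : lead >= 0xE0 ? 3 : lead >= 0xC0 ? 2 : 1;
  const size_t have = s.size() - i + 1;
  return have < need ? have : 0;  // bytes to hold back; 0 if complete
}

// Stand-in for an incremental decoder: each call returns only the newly
// finalized text, never emitting a trailing incomplete UTF-8 sequence.
class ToyIncrementalDecoder {
 public:
  std::string decode(const std::vector<int>& ids) {
    std::string text;
    for (const int id : ids) text += kTokenBytes[id];
    const size_t safe_end = text.size() - incomplete_utf8_suffix(text);
    if (safe_end <= emitted_) return "";
    const std::string delta = text.substr(emitted_, safe_end - emitted_);
    emitted_ = safe_end;
    return delta;
  }

 private:
  size_t emitted_ = 0;  // bytes already returned to the caller
};

int main() {
  const std::vector<int> ids = {0, 1, 2, 3};
  ToyIncrementalDecoder decoder;
  std::stringstream ss;
  // Same shape as the patched decode_text(): decode the leading tokens in
  // one call, then walk the tail one token at a time.
  const size_t start_idx = ids.size() > 2 ? ids.size() - 2 : 0;
  ss << decoder.decode({ids.begin(), ids.begin() + start_idx});
  for (size_t i = start_idx; i < ids.size(); ++i) {
    ss << decoder.decode({ids.begin(), ids.begin() + i + 1});
  }
  std::cout << ss.str() << "\n";  // prints: Café au lait
}

Here the bulk call over tokens {0, 1} emits "Caf" and holds back the lone lead byte "\xC3"; the first one-token step completes it to "é", and the final step appends " au lait", so the concatenated output contains no replacement characters. Driving the decoder token by token at the tail guarantees every held-back byte gets a chance to be completed, which is the behavior the patch restores for the final (non-streaming) output path.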