@@ -48,7 +48,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(
     auto k = static_cast<float>(draft_model_cache_size) / (main_model_cache_size + draft_model_cache_size);
 
     size_t main_cache_size = main_scheduler_config.cache_size * (1 - k),
-           draft_cache_size = main_scheduler_config.cache_size * k;
+           draft_cache_size = main_scheduler_config.cache_size - main_cache_size;
     if (draft_cache_size == 0) {
         main_cache_size -= main_cache_size > 1 ? 1 : 0;
         draft_cache_size = 1;
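
The subtraction in this hunk matters because `cache_size * (1 - k)` and `cache_size * k` truncate independently, so their sum can fall short of `cache_size`. A minimal standalone sketch of the rounding loss (all concrete sizes are hypothetical):

#include <cstddef>
#include <iostream>

// Minimal sketch of the rounding loss: the two products truncate
// independently and can sum to less than the total, while the subtraction
// guarantees an exact partition of the cache budget.
int main() {
    std::size_t cache_size = 10;   // hypothetical total KV-cache budget
    float k = 0.35f;               // draft model's share

    std::size_t main_cache_size = cache_size * (1 - k);         // 6 (6.5 truncated)
    std::size_t old_draft_size  = cache_size * k;                // 3 (3.5 truncated)
    std::size_t new_draft_size  = cache_size - main_cache_size;  // 4

    std::cout << main_cache_size + old_draft_size << '\n';  // 9: one unit lost
    std::cout << main_cache_size + new_draft_size << '\n';  // 10: exact partition
}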
@@ -158,6 +158,10 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() {
             m_draft_generations.erase(request_id);
         }
         auto updated_seq_info = update_sequence_info[request_id];
+        // the prompt phase may span several steps, so no draft tokens were inserted yet
+        if (updated_seq_info.inserted_tokens_cnt == 0) {
+            continue;
+        }
         float acceptance_rate = 1 - static_cast<float>(updated_seq_info.removed_tokens_cnt) / updated_seq_info.inserted_tokens_cnt;
         m_sd_metrics.update_acceptance_rate(request_id, acceptance_rate * 100);
         m_sd_metrics.update_draft_accepted_tokens(request_id, (updated_seq_info.inserted_tokens_cnt - updated_seq_info.removed_tokens_cnt));
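
The new `inserted_tokens_cnt == 0` guard protects the acceptance-rate formula on the following line, which divides by that counter. A hedged sketch of the failure mode, assuming the counters are unsigned integers and IEEE-754 floating point, as on typical platforms:

#include <cmath>
#include <cstddef>
#include <iostream>

// Sketch of the failure mode the guard prevents (counter types are assumed).
// With no draft tokens inserted, the formula divides zero by zero, which on
// IEEE-754 platforms yields NaN, and the NaN would then reach the metrics.
int main() {
    std::size_t removed_tokens_cnt = 0, inserted_tokens_cnt = 0;
    float acceptance_rate =
        1 - static_cast<float>(removed_tokens_cnt) / inserted_tokens_cnt;
    std::cout << std::isnan(acceptance_rate) << '\n';  // prints 1
}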
@@ -203,6 +207,10 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
     while (has_non_finished_requests() && continue_generation) {
         step();
         if (streamer_ptr) {
+            // nothing was generated this step (e.g. the prompt phase spans several steps)
+            if (!main_generations.at(0).get()->can_read()) {
+                continue;
+            }
             std::unordered_map<uint64_t, GenerationOutput> token = main_generations.at(0).get()->back();
             OPENVINO_ASSERT(1 <= token.size());
             OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size());
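
The `can_read()` check follows a check-before-read pattern: a step can finish without producing tokens, so the streamer must poll before taking the latest output. A sketch of that pattern under stated assumptions; `TokenQueue` and `poll_latest` are hypothetical stand-ins for the pipeline's GenerationHandle, not its real API:

#include <deque>
#include <optional>

// Hypothetical stand-in for a generation handle: tokens become readable
// only after a step that actually generated something.
struct TokenQueue {
    std::deque<int> tokens;
    bool can_read() const { return !tokens.empty(); }
    int back() const { return tokens.back(); }  // invalid if empty, hence the guard
};

// Poll before reading: return nothing when the step produced no tokens.
std::optional<int> poll_latest(const TokenQueue& q) {
    if (!q.can_read()) {
        return std::nullopt;  // mirrors the `continue` in the hunk above
    }
    return q.back();
}

int main() {
    TokenQueue q;                 // empty: e.g. a prompt-only step
    auto t = poll_latest(q);      // nullopt instead of an invalid back()
    q.tokens.push_back(42);
    t = poll_latest(q);           // now safely reads 42
    (void)t;
}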