
Commit 2ed9889

vshampor and l-bat authored
Token eviction (openvinotoolkit#757)
Co-authored-by: Liubov Talamanova <liubov.talamanova@intel.com>
1 parent c58ba64 commit 2ed9889

40 files changed, +3475 −499 lines changed

.github/workflows/linux.yml

+1 −1

@@ -189,7 +189,7 @@ jobs:
     if: |
       always() &&
       (needs.openvino_download.outputs.status == 'success' || needs.openvino_build.result == 'success')
-    timeout-minutes: 90
+    timeout-minutes: 120
     defaults:
       run:
         shell: bash
@@ -0,0 +1,96 @@
from pathlib import PosixPath
import os
import tempfile

import whowhatbench
from whowhatbench.wwb import load_dataset
from optimum.intel.openvino import OVModelForCausalLM

from openvino_genai import ContinuousBatchingPipeline, SchedulerConfig, GenerationConfig, CacheEvictionConfig, AggregationMode

from openvino_tokenizers import convert_tokenizer
from openvino import serialize
from transformers import AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MAX_NEW_TOKENS = 128
SEQS_PER_REQUEST = 5
MAX_SEQUENCES = 100


# Export the HF model to OpenVINO IR and store the converted tokenizer/detokenizer next to it.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model_path = PosixPath(tempfile.gettempdir()) / model_id
model.save_pretrained(model_path)

ov_tokenizer, ov_detokenizer = convert_tokenizer(tokenizer, with_detokenizer=True, skip_special_tokens=True)
serialize(ov_tokenizer, model_path / "openvino_tokenizer.xml")
serialize(ov_detokenizer, model_path / "openvino_detokenizer.xml")

# Baseline scheduler configuration: no cache eviction.
scheduler_config_noopt = SchedulerConfig()
scheduler_config_noopt.num_kv_blocks = 300
scheduler_config_noopt.dynamic_split_fuse = True
scheduler_config_noopt.max_num_batched_tokens = 256
scheduler_config_noopt.max_num_seqs = 256
scheduler_config_noopt.enable_prefix_caching = False

# Optimized scheduler configuration: cache eviction enabled
# (start area 32 tokens, recent area 32 tokens, max 128 tokens per sequence, NORM_SUM aggregation).
scheduler_config_opt = SchedulerConfig()
scheduler_config_opt.num_kv_blocks = 300
scheduler_config_opt.dynamic_split_fuse = True
scheduler_config_opt.max_num_batched_tokens = 256
scheduler_config_opt.max_num_seqs = 256
scheduler_config_opt.use_cache_eviction = True
scheduler_config_opt.enable_prefix_caching = False
eviction_config = CacheEvictionConfig(32, 32, 128, AggregationMode.NORM_SUM)
scheduler_config_opt.cache_eviction_config = eviction_config

generation_config = GenerationConfig()
generation_config.num_return_sequences = 1
generation_config.max_new_tokens = MAX_NEW_TOKENS

# Deduplicate the SQuAD contexts (a dict preserves insertion order) and keep the first MAX_SEQUENCES of them.
data = load_dataset(path='squad', name=None, split='validation')["context"]
data_dict = {"questions": list(dict({k: None for k in data}).keys())[:MAX_SEQUENCES]}

model_cb_noopt = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config_noopt, "CPU", {})
model_cb_opt = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config_opt, "CPU", {})


GT_DATA_FILE = 'gt_data.csv'

# Reuse previously dumped ground-truth answers if available; otherwise generate them
# with the non-eviction pipeline and cache them for later runs.
if os.path.exists(GT_DATA_FILE):
    evaluator = whowhatbench.Evaluator(base_model=model_cb_noopt, gt_data=GT_DATA_FILE, tokenizer=tokenizer,
                                       test_data=data_dict, generation_config=generation_config,
                                       max_new_tokens=MAX_NEW_TOKENS, seqs_per_request=3)
else:
    evaluator = whowhatbench.Evaluator(base_model=model_cb_noopt, tokenizer=tokenizer, test_data=data_dict,
                                       generation_config=generation_config, max_new_tokens=MAX_NEW_TOKENS,
                                       seqs_per_request=3)
    evaluator.dump_gt(GT_DATA_FILE)


all_metrics_per_question, all_metrics = evaluator.score(model_cb_opt)


print(all_metrics_per_question)
print(all_metrics)

metrics = ["similarity", "SDT norm"]

# Show the prompts where the eviction-enabled pipeline diverges most from the baseline.
for metric in metrics:
    worst_examples = evaluator.worst_examples(top_k=5, metric=metric)
    print("Metric: ", metric)
    for e in worst_examples:
        print("\t=========================")
        print(f"\t{metric}: ", e[metric])
        print("\tPrompt: ", e["prompt"])
        print("\tSource Model:\n ", "\t" + e["source_model"])
        print("\tOptimized Model:\n ", "\t" + e["optimized_model"])

# Compare KV cache usage between the two pipelines.
pipeline_opt_metrics = model_cb_opt.get_metrics()
pipeline_noopt_metrics = model_cb_noopt.get_metrics()

print(f"No-opt cache usage: max {pipeline_noopt_metrics.max_cache_usage:.3f}, avg {pipeline_noopt_metrics.avg_cache_usage:.3f}")
print(f"Opt cache usage: max {pipeline_opt_metrics.max_cache_usage:.3f}, avg {pipeline_opt_metrics.avg_cache_usage:.3f}")
max_optimization_ratio = (pipeline_noopt_metrics.max_cache_usage / pipeline_opt_metrics.max_cache_usage)
avg_optimization_ratio = (pipeline_noopt_metrics.avg_cache_usage / pipeline_opt_metrics.avg_cache_usage)
print(f"Optimization ratios: max {max_optimization_ratio:.3f}x, avg {avg_optimization_ratio:.3f}x")
@@ -1,10 +1,9 @@
 transformers>=4.35.2
 sentence-transformers>=2.2.2
 openvino>=2024.3.0
-openvino-telemetry>=2024.3.0
+openvino-telemetry
 optimum-intel>=1.14
-openvino-tokenizers>=2024.3.0
-openvino-genai>=2024.3.0
+openvino-tokenizers
 pandas>=2.0.3
 numpy>=1.23.5
 tqdm>=4.66.1

llm_bench/python/who_what_benchmark/whowhatbench/evaluator.py

+29 −5
@@ -81,6 +81,8 @@ def autodetect_language(model):
         "internlm": "cn",
     }
 
+    if not hasattr(model, "config"):
+        return "en"
     return model2language.get(model.config.model_type, "en")
 
 
@@ -98,6 +100,9 @@ def __init__(
         num_samples=None,
         language=None,
         gen_answer_fn=None,
+        generation_config=None,
+        generation_config_base=None,
+        seqs_per_request=None
     ) -> None:
         assert (
             base_model is not None or gt_data is not None
@@ -109,6 +114,11 @@
         self.tokenizer = tokenizer
         self._crop_question = crop_question
         self.num_samples = num_samples
+        self.generation_config = generation_config
+        self.generation_config_base = generation_config
+        self.seqs_per_request = seqs_per_request
+        if self.generation_config is not None:
+            assert self.seqs_per_request is not None
 
         # Take language from the base model if provided
         self.language = language
@@ -117,7 +127,7 @@
             self.language = autodetect_language(base_model)
 
         if base_model:
-            self.gt_data = self._generate_data(base_model, gen_answer_fn)
+            self.gt_data = self._generate_data(base_model, gen_answer_fn, generation_config=generation_config)
         else:
             self.gt_data = pd.read_csv(gt_data, keep_default_na=False)
 
@@ -139,7 +149,7 @@ def dump_gt(self, csv_name: str):
         self.gt_data.to_csv(csv_name)
 
     def score(self, model, gen_answer_fn=None):
-        predictions = self._generate_data(model, gen_answer_fn)
+        predictions = self._generate_data(model, gen_answer_fn, self.generation_config)
 
         all_metrics_per_question = {}
         all_metrics = {}
@@ -179,9 +189,10 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
 
         return res
 
-    def _generate_data(self, model, gen_answer_fn=None):
+    def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
         def default_gen_answer(model, tokenizer, question, max_new_tokens, crop_question):
             inputs = self.tokenizer(question, return_tensors="pt")
+
             tokens = model.generate(**inputs, max_new_tokens=max_new_tokens)
             out = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
             return out[len(question) :] if crop_question else out
@@ -209,8 +220,21 @@ def default_gen_answer(model, tokenizer, question, max_new_tokens, crop_question
         answers = []
         prompts = questions.values if self.num_samples is None else questions.values[:self.num_samples]
 
-        for q in tqdm(prompts, desc="Evaluate pipeline"):
-            answers.append(gen_answer_fn(model, self.tokenizer, q, self.max_new_tokens, self._crop_question))
+        if generation_config is None:
+            for q in tqdm(prompts, desc="Evaluate pipeline"):
+                answers.append(gen_answer_fn(model, self.tokenizer, q, self.max_new_tokens, self._crop_question))
+        else:
+            with tqdm(total=len(questions.values)) as progress_bar:
+                batch = []
+                for q_idx, q in enumerate(questions.values):
+                    progress_bar.update(1)
+                    batch.append(q)
+                    if len(batch) == self.seqs_per_request or q_idx == len(questions.values) - 1:
+                        ans_batch = model.generate(batch, [generation_config] * len(batch))
+                        for ans in ans_batch:
+                            answers.append(ans.m_generation_ids[0])
+
+                        batch.clear()
 
         res_data = {"questions": list(prompts), "answers": answers}
         df = pd.DataFrame(res_data)
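
The generation_config branch added to _generate_data accumulates prompts into batches of seqs_per_request and submits each batch to the pipeline's generate() call with one config per prompt. A standalone sketch of the same chunking pattern, with a stand-in generate_fn instead of the real ContinuousBatchingPipeline.generate:

# Minimal sketch of the batched-generation loop above; `generate_fn` stands in
# for ContinuousBatchingPipeline.generate(prompts, configs).
def batched_answers(prompts, seqs_per_request, generation_config, generate_fn):
    answers = []
    batch = []
    for idx, prompt in enumerate(prompts):
        batch.append(prompt)
        if len(batch) == seqs_per_request or idx == len(prompts) - 1:
            for result in generate_fn(batch, [generation_config] * len(batch)):
                answers.append(result)  # real results expose m_generation_ids[0]
            batch.clear()
    return answers

# Example with a stand-in generate function:
print(batched_answers(["a", "b", "c", "d", "e"], 2, None,
                      lambda prompts, cfgs: [p.upper() for p in prompts]))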

samples/cpp/continuous_batching_benchmark/continuous_batching_benchmark.cpp

+8 −1
@@ -14,6 +14,7 @@
 #include <nlohmann/json.hpp>
 #include <cxxopts.hpp>
 
+#include "openvino/genai/cache_eviction.hpp"
 #include "openvino/genai/tokenizer.hpp"
 #include "openvino/genai/continuous_batching_pipeline.hpp"
 #include "openvino/genai/generation_handle.hpp"
@@ -440,6 +441,7 @@ int main(int argc, char* argv[]) try {
     ("cache_size", "Size of memory used for KV cache in GB. Default: 16", cxxopts::value<size_t>()->default_value("16"))
     ("device", "Target device to run the model. Default: CPU", cxxopts::value<std::string>()->default_value("CPU"))
     ("device_config", "Plugin configuration JSON. Example: '{\"MODEL_DISTRIBUTION_POLICY\":\"TENSOR_PARALLEL\",\"PERF_COUNT\":true}' Default: {\"PERF_COUNT\":true}", cxxopts::value<std::string>()->default_value("{\"PERF_COUNT\":true}"))
+    ("use_cache_eviction", "Whether to use cache eviction", cxxopts::value<bool>()->default_value("false"))
     ("h,help", "Print usage");
 
     cxxopts::ParseResult result;
@@ -467,6 +469,7 @@ int main(int argc, char* argv[]) try {
     const std::string device = result["device"].as<std::string>();
     const std::string device_config = result["device_config"].as<std::string>();
     const size_t cache_size = result["cache_size"].as<size_t>();
+    const bool use_cache_eviction = result["use_cache_eviction"].as<bool>();
 
     // Create requests for generation
     Dataset dataset = filtered_dataset(models_path, dataset_path, num_prompts, max_input_len, max_output_len);
@@ -486,7 +489,11 @@
     scheduler_config.cache_size = cache_size,
     scheduler_config.block_size = get_default_block_size(device),
     scheduler_config.dynamic_split_fuse = dynamic_split_fuse,
-    scheduler_config.max_num_seqs = 256, // not used if dynamic_split_fuse=True
+    scheduler_config.max_num_seqs = 256; // not used if dynamic_split_fuse=True
+    if (use_cache_eviction) {
+        scheduler_config.use_cache_eviction = true;
+        scheduler_config.cache_eviction_config = ov::genai::CacheEvictionConfig(32, 32, 128, ov::genai::AggregationMode::NORM_SUM);
+    }
 
     std::cout << "Benchmarking parameters: " << std::endl;
     std::cout << "\tMax number of batched tokens: " << scheduler_config.max_num_batched_tokens << std::endl;
@@ -0,0 +1,83 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <cstddef>
#include "openvino/openvino.hpp"

namespace ov::genai {
/**
 * @brief Represents the mode of per-token score aggregation when determining least important tokens for eviction
 * from cache
 */
enum class AggregationMode {
    SUM,     /**< In this mode the importance scores of each token will be summed after each step of generation */
    NORM_SUM /**< Same as SUM, but the importance scores are additionally divided by the lifetime (in tokens generated)
              * of a given token in cache */
};

/**
 * @brief Configuration struct for the cache eviction algorithm.
 */
class CacheEvictionConfig {
public:
    CacheEvictionConfig() {};
    CacheEvictionConfig(size_t start_size, size_t recent_size, size_t max_cache_size, AggregationMode aggregation_mode_) : aggregation_mode(aggregation_mode_), m_start_size(start_size), m_recent_size(recent_size), m_max_cache_size(max_cache_size) {
        OPENVINO_ASSERT(start_size, "CacheEvictionConfig.start_size must be non-zero");
        OPENVINO_ASSERT(recent_size, "CacheEvictionConfig.recent_size must be non-zero");
        OPENVINO_ASSERT(max_cache_size, "CacheEvictionConfig.max_cache_size must be non-zero");

        OPENVINO_ASSERT(max_cache_size > (start_size + recent_size),
                        "CacheEvictionConfig.max_cache_size must be larger than CacheEvictionConfig.start_size + CacheEvictionConfig.recent_size");
        m_evictable_size = m_max_cache_size - m_start_size - m_recent_size;
    }

    /** @return Number of tokens in the *beginning* of KV cache that will be retained for the sequence. */
    std::size_t get_start_size() const {
        return m_start_size;
    }

    /** @return Number of tokens in the *end* of KV cache that will be retained for the sequence. */
    std::size_t get_recent_size() const {
        return m_recent_size;
    }

    /** @return Maximum cache size (in tokens) that a sequence with cache eviction enabled is allowed to occupy. */
    std::size_t get_max_cache_size() const {
        return m_max_cache_size;
    }

    /** @return Number of tokens between the "start" and "recent" areas of KV cache that
     * will be considered for eviction. */
    std::size_t get_evictable_size() const {
        return m_evictable_size;
    }

    /** The mode used to compute the importance of tokens for eviction */
    AggregationMode aggregation_mode = AggregationMode::NORM_SUM;

private:
    /** Number of tokens in the *beginning* of KV cache that should be retained
     * in the KV cache for this sequence during generation. Must be non-zero and a multiple of the KV cache block size for
     * this pipeline.*/
    std::size_t m_start_size = 32;

    /** Number of tokens in the *end* of KV cache that should be retained
     * in the KV cache for this sequence during generation. Must be non-zero and a multiple of the KV cache block size for
     * this pipeline.*/
    std::size_t m_recent_size = 128;

    /**
     * @brief Maximum cache size (in tokens) that can be occupied by a sequence with cache eviction enabled.
     * Actual occupied size may differ from this by no larger than (block_size) tokens.
     * Eviction area is computed from this size and the "start"/"recent" area sizes.
     * @return Total cache size (in tokens) allowed to be occupied by a sequence.
     */
    std::size_t m_max_cache_size = 672;

    std::size_t m_evictable_size = 512;
};
}
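
The two aggregation modes differ only in normalization: SUM accumulates each cached token's importance score across generation steps, while NORM_SUM additionally divides the total by how many generated tokens the cached token has lived through. A rough, illustrative Python sketch of that scoring rule (not the pipeline's actual eviction code):

# Illustrative scoring only; the real eviction logic lives inside the C++ pipeline.
def aggregate_scores(per_step_scores, lifetimes, normalize):
    """per_step_scores[t][i]: importance of cached token i at generation step t;
    lifetimes[i]: number of generated tokens that cached token i has lived through."""
    totals = [sum(step[i] for step in per_step_scores) for i in range(len(lifetimes))]
    if normalize:  # NORM_SUM: favor tokens that score well relative to their age
        totals = [s / max(life, 1) for s, life in zip(totals, lifetimes)]
    return totals  # lowest-scoring tokens in the evictable area are evicted first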

src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp

+33 −5
@@ -13,15 +13,39 @@
 #include "openvino/genai/llm_pipeline.hpp"
 #include "openvino/genai/streamer_base.hpp"
 #include "openvino/genai/visibility.hpp"
+#include "cache_eviction.hpp"
 
 namespace ov::genai {
-struct PipelineMetrics {
-    // All requests as viewed by the pipeline
+
+/**
+ * @brief Contains general pipeline metrics, either aggregated throughout the lifetime of the generation pipeline
+ * or measured at the previous generation step.
+ */
+struct PipelineMetrics {
+    /**
+     * Number of requests to be processed by the pipeline.
+     */
     size_t requests = 0;
-    // Requests scheduled for processing
+
+    /**
+     * Number of requests that were scheduled for processing at the previous step of the pipeline.
+     */
     size_t scheduled_requests = 0;
-    // Percentage of KV cache usage
+
+    /**
+     * Percentage of KV cache usage in the last generation step.
+     */
     float cache_usage = 0.0;
+
+    /**
+     * Max KV cache usage during the lifetime of the pipeline in %
+     */
+    float max_cache_usage = 0.0;
+
+    /**
+     * Running average of the KV cache usage during the lifetime of the pipeline, with max window size of 1000 steps
+     */
+    float avg_cache_usage = 0.0;
 };
 
 class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
@@ -57,7 +81,11 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
 
     ov::genai::GenerationConfig get_config() const;
 
-    PipelineMetrics get_metrics() const;
+    /**
+     * Allows to get the current pipeline metrics.
+     * @return The struct with pipeline metrics for the previous generation step.
+     */
+    ov::genai::PipelineMetrics get_metrics() const;
 
     GenerationHandle add_request(uint64_t request_id, const ov::Tensor& input_ids, const ov::genai::GenerationConfig& sampling_params);
     GenerationHandle add_request(uint64_t request_id, const std::string& prompt, const ov::genai::GenerationConfig& sampling_params);
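
The new avg_cache_usage field is documented as a running average with a window of at most 1000 steps. A minimal sketch of that kind of windowed average, purely illustrative of the documented behaviour rather than the pipeline's internal implementation:

# Illustrative windowed running average matching the documented "max window
# size of 1000 steps" in spirit; not the pipeline's internal code.
from collections import deque

class RunningAverage:
    def __init__(self, max_window: int = 1000):
        self._window = deque(maxlen=max_window)  # oldest samples fall out automatically

    def add(self, value: float) -> None:
        self._window.append(value)

    @property
    def value(self) -> float:
        return sum(self._window) / len(self._window) if self._window else 0.0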
