
Commit b31b6a1

Enable print properties of compiled model in genai API (openvinotoolkit#1289)

When setting the environment variable OPENVINO_LOG_LEVEL > ov::log::Level::WARNING, the properties of the compiled model can be printed in the genai API. When the device is CPU, the properties of the compiled model are as follows:

```sh
Model: Stateful LLM model
NETWORK_NAME: Model0
OPTIMAL_NUMBER_OF_INFER_REQUESTS: 1
NUM_STREAMS: 1
INFERENCE_NUM_THREADS: 48
PERF_COUNT: NO
INFERENCE_PRECISION_HINT: bf16
PERFORMANCE_HINT: LATENCY
EXECUTION_MODE_HINT: PERFORMANCE
PERFORMANCE_HINT_NUM_REQUESTS: 0
ENABLE_CPU_PINNING: YES
SCHEDULING_CORE_TYPE: ANY_CORE
MODEL_DISTRIBUTION_POLICY:
ENABLE_HYPER_THREADING: NO
EXECUTION_DEVICES: CPU
CPU_DENORMALS_OPTIMIZATION: NO
LOG_LEVEL: LOG_NONE
CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1
DYNAMIC_QUANTIZATION_GROUP_SIZE: 32
KV_CACHE_PRECISION: f16
AFFINITY: CORE
EXECUTION_DEVICES:
  CPU: Intel(R) Xeon(R) Platinum 8468
```

[stable_diffusion_compiled_model_log.txt](https://github.com/user-attachments/files/18120641/stable_diffusion_compiled_model_log.txt)

---------

Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
1 parent 79f64a6 commit b31b6a1

24 files changed: +152 −29 lines changed
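To see this output from application code, it is enough to set the variable before the pipeline compiles its models. A minimal sketch against the public C++ genai API (the model directory and prompt are illustrative, not part of this commit):

```cpp
// Minimal sketch: enable compiled-model property logging from the genai API.
// The model directory below is illustrative; any converted OpenVINO model works.
#include <cstdlib>
#include <iostream>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // The variable is read at compile_model() time, so set it before the
    // pipeline is constructed. 3 corresponds to INFO, i.e. above WARNING.
#ifdef _WIN32
    _putenv_s("OPENVINO_LOG_LEVEL", "3");
#else
    setenv("OPENVINO_LOG_LEVEL", "3", 1);
#endif
    ov::genai::LLMPipeline pipe("./TinyLlama-1.1B-Chat-v1.0", "CPU");  // properties are printed here
    std::string result = pipe.generate("Hello,", ov::genai::max_new_tokens(16));
    std::cout << result << std::endl;
}
```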

.github/workflows/llm_bench-python.yml (+1)

```diff
@@ -61,6 +61,7 @@ jobs:
       SRC_DIR: ${{ github.workspace }}
       LLM_BENCH_PYPATH: ${{ github.workspace }}/tools/llm_bench
       WWB_PATH: ${{ github.workspace }}/tools/who_what_benchmark
+      OPENVINO_LOG_LEVEL: 3
 
     steps:
     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
```

src/README.md (+4)

```diff
@@ -403,3 +403,7 @@ For information on how OpenVINO™ GenAI works, refer to the [How It Works Secti
 ## Supported Models
 
 For a list of supported models, refer to the [Supported Models Section](./docs/SUPPORTED_MODELS.md).
+
+## Debug Log
+
+For using debug log, refer to [DEBUG Log](./doc/DEBUG_LOG.md).
```

src/cpp/src/continuous_batching_impl.cpp (+3 −1)

```diff
@@ -46,7 +46,9 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
     const ov::AnyMap& properties,
     const DeviceConfig& device_config,
     ov::Core& core) {
-    ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), properties).create_infer_request();
+    auto compiled_model = core.compile_model(model, device_config.get_device(), properties);
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention");
+    ov::InferRequest infer_request = compiled_model.create_infer_request();
 
     // setup KV caches
     m_cache_manager = std::make_shared<CacheManager>(device_config, core);
```

src/cpp/src/image_generation/models/autoencoder_kl.cpp (+2)

```diff
@@ -212,12 +212,14 @@ AutoencoderKL& AutoencoderKL::compile(const std::string& device, const ov::AnyMa
 
     if (m_encoder_model) {
         ov::CompiledModel encoder_compiled_model = core.compile_model(m_encoder_model, device, properties);
+        ov::genai::utils::print_compiled_model_properties(encoder_compiled_model, "Auto encoder KL encoder model");
         m_encoder_request = encoder_compiled_model.create_infer_request();
         // release the original model
         m_encoder_model.reset();
     }
 
     ov::CompiledModel decoder_compiled_model = core.compile_model(m_decoder_model, device, properties);
+    ov::genai::utils::print_compiled_model_properties(decoder_compiled_model, "Auto encoder KL decoder model");
     m_decoder_request = decoder_compiled_model.create_infer_request();
     // release the original model
     m_decoder_model.reset();
```

src/cpp/src/image_generation/models/clip_text_model.cpp (+1)

```diff
@@ -97,6 +97,7 @@ CLIPTextModel& CLIPTextModel::compile(const std::string& device, const ov::AnyMa
     } else {
         compiled_model = core.compile_model(m_model, device, properties);
     }
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "Clip Text model");
     m_request = compiled_model.create_infer_request();
     // release the original model
     m_model.reset();
```

src/cpp/src/image_generation/models/clip_text_model_with_projection.cpp (+1)

```diff
@@ -88,6 +88,7 @@ CLIPTextModelWithProjection& CLIPTextModelWithProjection::compile(const std::str
     } else {
         compiled_model = core.compile_model(m_model, device, properties);
     }
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "Clip Text with projection model");
     m_request = compiled_model.create_infer_request();
     // release the original model
     m_model.reset();
```

src/cpp/src/image_generation/models/flux_transformer_2d_model.cpp (+1)

```diff
@@ -108,6 +108,7 @@ FluxTransformer2DModel& FluxTransformer2DModel::reshape(int batch_size,
 FluxTransformer2DModel& FluxTransformer2DModel::compile(const std::string& device, const ov::AnyMap& properties) {
     OPENVINO_ASSERT(m_model, "Model has been already compiled. Cannot re-compile already compiled model");
     ov::CompiledModel compiled_model = utils::singleton_core().compile_model(m_model, device, properties);
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "Flux Transformer 2D model");
     m_request = compiled_model.create_infer_request();
     // release the original model
     m_model.reset();
```

src/cpp/src/image_generation/models/sd3_transformer_2d_model.cpp (+1)

```diff
@@ -105,6 +105,7 @@ SD3Transformer2DModel& SD3Transformer2DModel::reshape(int batch_size,
 SD3Transformer2DModel& SD3Transformer2DModel::compile(const std::string& device, const ov::AnyMap& properties) {
     OPENVINO_ASSERT(m_model, "Model has been already compiled. Cannot re-compile already compiled model");
     ov::CompiledModel compiled_model = utils::singleton_core().compile_model(m_model, device, properties);
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "SD3 Transformer 2D model");
     m_request = compiled_model.create_infer_request();
     // release the original model
     m_model.reset();
```

src/cpp/src/image_generation/models/t5_encoder_model.cpp (+1)

```diff
@@ -63,6 +63,7 @@ T5EncoderModel& T5EncoderModel::compile(const std::string& device, const ov::Any
     ov::Core core = utils::singleton_core();
     ov::CompiledModel compiled_model;
     compiled_model = core.compile_model(m_model, device, properties);
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "T5 encoder model");
     m_request = compiled_model.create_infer_request();
     // release the original model
     m_model.reset();
```

src/cpp/src/image_generation/models/unet_inference_dynamic.hpp (+1)

```diff
@@ -20,6 +20,7 @@ class UNet2DConditionModel::UNetInferenceDynamic : public UNet2DConditionModel::
         ov::Core core = utils::singleton_core();
 
         ov::CompiledModel compiled_model = core.compile_model(model, device, properties);
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "UNet 2D Condition dynamic model");
         m_request = compiled_model.create_infer_request();
     }
 
```

src/cpp/src/image_generation/models/unet_inference_static_bs1.hpp (+1)

```diff
@@ -40,6 +40,7 @@ class UNet2DConditionModel::UNetInferenceStaticBS1 : public UNet2DConditionModel
 
         ov::Core core = utils::singleton_core();
         ov::CompiledModel compiled_model = core.compile_model(model, device, properties);
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "UNet 2D Condition batch-1 model");
 
         for (int i = 0; i < m_native_batch_size; i++)
         {
```

src/cpp/src/llm_pipeline.cpp (+6 −2)

```diff
@@ -77,17 +77,21 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         const ov::genai::GenerationConfig& generation_config
     ) : LLMPipelineImplBase(tokenizer, generation_config) {
         ov::Core core;
+        ov::CompiledModel compiled_model;
         auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config);
         utils::slice_matmul_statefull_model(model);
         m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);
 
         if (auto filtered_plugin_config = extract_adapters_from_properties(plugin_config, &m_generation_config.adapters)) {
             m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
             m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device);  // TODO: Make the prefix name configurable
-            m_model_runner = core.compile_model(model, device, *filtered_plugin_config).create_infer_request();
+            compiled_model = core.compile_model(model, device, *filtered_plugin_config);
+            m_model_runner = compiled_model.create_infer_request();
         } else {
-            m_model_runner = core.compile_model(model, device, plugin_config).create_infer_request();
+            compiled_model = core.compile_model(model, device, plugin_config);
+            m_model_runner = compiled_model.create_infer_request();
         }
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "Stateful LLM model");
 
         // If eos_token_id was not provided, take value
         if (m_generation_config.eos_token_id == -1)
```

src/cpp/src/llm_pipeline_static.cpp (+8 −5)

```diff
@@ -777,12 +777,15 @@ void StaticLLMPipeline::setupAndCompileModels(
     set_npuw_cache_dir(prefill_config);
     set_npuw_cache_dir(generate_config);
 
-    m_kvcache_request = core.compile_model(
+    auto kv_compiled_model = core.compile_model(
         kvcache_model, device, generate_config
-    ).create_infer_request();
-    m_prefill_request = core.compile_model(
-        prefill_model, device, prefill_config
-    ).create_infer_request();
+    );
+    ov::genai::utils::print_compiled_model_properties(kv_compiled_model, "Static LLM kv compiled model");
+    m_kvcache_request = kv_compiled_model.create_infer_request();
+
+    auto prefill_compiled_model = core.compile_model(prefill_model, device, prefill_config);
+    m_prefill_request = prefill_compiled_model.create_infer_request();
+    ov::genai::utils::print_compiled_model_properties(prefill_compiled_model, "Static LLM prefill compiled model");
 }
 
 void StaticLLMPipeline::setupAndImportModels(
```

src/cpp/src/lora_adapter.cpp (+3 −1)

```diff
@@ -637,7 +637,9 @@ class InferRequestSignatureCache {
 
         ov::Core core = ov::genai::utils::singleton_core();
         auto model = std::make_shared<ov::Model>(request_results, request_parameters);
-        rwb.request = core.compile_model(model, device).create_infer_request();
+        auto compiled_model = core.compile_model(model, device);
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "Infer Request Signature Cache");
+        rwb.request = compiled_model.create_infer_request();
         requests.emplace(signature, rwb);
     }
```
src/cpp/src/tokenizer.cpp (+2)

```diff
@@ -203,6 +203,7 @@ class Tokenizer::TokenizerImpl {
         manager.register_pass<MakeCombineSegmentsSatateful>();
         manager.run_passes(ov_tokenizer);
         m_tokenizer = core.compile_model(ov_tokenizer, device, properties);
+        ov::genai::utils::print_compiled_model_properties(m_tokenizer, "OV Tokenizer");
 
         m_ireq_queue_tokenizer = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
             m_tokenizer.get_property(ov::optimal_number_of_infer_requests),
@@ -216,6 +217,7 @@ class Tokenizer::TokenizerImpl {
         manager_detok.register_pass<MakeVocabDecoderSatateful>();
         manager_detok.run_passes(ov_detokenizer);
         m_detokenizer = core.compile_model(ov_detokenizer, device, properties);
+        ov::genai::utils::print_compiled_model_properties(m_detokenizer, "OV Detokenizer");
 
         m_ireq_queue_detokenizer = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
             m_detokenizer.get_property(ov::optimal_number_of_infer_requests),
```

src/cpp/src/utils.cpp (+37)

```diff
@@ -381,6 +381,43 @@ void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t se
     }
 }
 
+void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const char* model_title) {
+    // Specify the name of the environment variable
+    const char* env_var_name = "OPENVINO_LOG_LEVEL";
+    const char* env_var_value = std::getenv(env_var_name);
+
+    // Check if the environment variable was found
+    if (env_var_value != nullptr && atoi(env_var_value) > static_cast<int>(ov::log::Level::WARNING)) {
+        // output of the actual settings that the device selected
+        auto supported_properties = compiled_Model.get_property(ov::supported_properties);
+        std::cout << "Model: " << model_title << std::endl;
+        for (const auto& cfg : supported_properties) {
+            if (cfg == ov::supported_properties)
+                continue;
+            auto prop = compiled_Model.get_property(cfg);
+            if (cfg == ov::device::properties) {
+                auto devices_properties = prop.as<ov::AnyMap>();
+                for (auto& item : devices_properties) {
+                    std::cout << "  " << item.first << ": " << std::endl;
+                    for (auto& item2 : item.second.as<ov::AnyMap>()) {
+                        std::cout << "    " << item2.first << ": " << item2.second.as<std::string>() << std::endl;
+                    }
+                }
+            } else {
+                std::cout << "  " << cfg << ": " << prop.as<std::string>() << std::endl;
+            }
+        }
+
+        ov::Core core;
+        std::vector<std::string> exeTargets;
+        exeTargets = compiled_Model.get_property(ov::execution_devices);
+        std::cout << "EXECUTION_DEVICES:" << std::endl;
+        for (const auto& device : exeTargets) {
+            std::cout << "  " << device << ": " << core.get_property(device, ov::device::full_name) << std::endl;
+        }
+    }
+}
+
 }  // namespace utils
 }  // namespace genai
 }  // namespace ov
```
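The gate in `print_compiled_model_properties` is a plain integer comparison: the helper prints only when `OPENVINO_LOG_LEVEL` parses to a value above `ov::log::Level::WARNING`. A stand-alone sketch of the same check (the function name here is hypothetical, not part of the patch):

```cpp
// Hypothetical stand-alone version of the env-var gate used above: logging is
// enabled only when OPENVINO_LOG_LEVEL parses to a value above WARNING.
#include <cstdlib>
#include "openvino/runtime/properties.hpp"

static bool compiled_model_logging_enabled() {
    const char* value = std::getenv("OPENVINO_LOG_LEVEL");
    // Note: std::atoi() yields 0 for a malformed value, which keeps logging off.
    return value != nullptr && std::atoi(value) > static_cast<int>(ov::log::Level::WARNING);
}
```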

src/cpp/src/utils.hpp (+2)

```diff
@@ -104,6 +104,8 @@ size_t get_seq_len_axis(std::shared_ptr<const ov::Model> model);
 
 void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller);
 
+void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const char* model_title);
+
 }  // namespace utils
 }  // namespace genai
 }  // namespace ov
```

src/cpp/src/visual_language/embedding_model.cpp (+1)

```diff
@@ -26,6 +26,7 @@ EmbeddingsModel::EmbeddingsModel(const std::filesystem::path& model_dir,
     merge_postprocess(m_model, scale_emb);
 
     ov::CompiledModel compiled_model = core.compile_model(m_model, device, properties);
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "text embeddings model");
     m_request = compiled_model.create_infer_request();
 }
 
```

src/cpp/src/visual_language/inputs_embedder.cpp (+4 −3)

```diff
@@ -259,9 +259,10 @@ class InputsEmbedderMiniCPM : public InputsEmbedder::IInputsEmbedder {
         const std::string& device,
         const ov::AnyMap device_config) :
         IInputsEmbedder(vlm_config, model_dir, device, device_config) {
-        m_resampler = utils::singleton_core().compile_model(
-            model_dir / "openvino_resampler_model.xml", device, device_config
-        ).create_infer_request();
+        auto compiled_model =
+            utils::singleton_core().compile_model(model_dir / "openvino_resampler_model.xml", device, device_config);
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "VLM resampler model");
+        m_resampler = compiled_model.create_infer_request();
 
         m_pos_embed_cache = get_2d_sincos_pos_embed(m_vlm_config.hidden_size, {70, 70});
     }
```

src/cpp/src/visual_language/pipeline.cpp (+1 −1)

```diff
@@ -92,7 +92,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         auto compiled_language_model = utils::singleton_core().compile_model(
             models_dir / "openvino_language_model.xml", device, properties
         );
-
+        ov::genai::utils::print_compiled_model_properties(compiled_language_model, "VLM language model");
         auto language_model = compiled_language_model.get_runtime_model();
         m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(language_model);
 
```

src/cpp/src/visual_language/vision_encoder.cpp (+6 −4)

```diff
@@ -648,10 +648,12 @@ ov::Tensor get_pixel_values_internvl(const ov::Tensor& image, const ProcessorCon
 
 VisionEncoder::VisionEncoder(const std::filesystem::path& model_dir, const VLMModelType model_type, const std::string& device, const ov::AnyMap device_config) :
     model_type(model_type) {
-    m_vision_encoder = utils::singleton_core().compile_model(model_dir / "openvino_vision_embeddings_model.xml", device, device_config).create_infer_request();
-    m_processor_config = utils::from_config_json_if_exists<ProcessorConfig>(
-        model_dir, "preprocessor_config.json"
-    );
+    auto compiled_model = utils::singleton_core().compile_model(model_dir / "openvino_vision_embeddings_model.xml",
+                                                                device,
+                                                                device_config);
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "VLM vision embeddings model");
+    m_vision_encoder = compiled_model.create_infer_request();
+    m_processor_config = utils::from_config_json_if_exists<ProcessorConfig>(model_dir, "preprocessor_config.json");
 }
 
 VisionEncoder::VisionEncoder(
```

src/cpp/src/whisper_pipeline.cpp (+12 −9)

```diff
@@ -56,15 +56,18 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
         auto [core_properties, compile_properties] = ov::genai::utils::split_core_compile_config(properties);
         core.set_property(core_properties);
 
-        m_models.encoder =
-            core.compile_model((models_path / "openvino_encoder_model.xml").string(), device, compile_properties)
-                .create_infer_request();
-        m_models.decoder =
-            core.compile_model((models_path / "openvino_decoder_model.xml").string(), device, compile_properties)
-                .create_infer_request();
-        m_models.decoder_with_past =
-            core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, compile_properties)
-                .create_infer_request();
+        ov::CompiledModel compiled_model;
+        compiled_model =
+            core.compile_model((models_path / "openvino_encoder_model.xml").string(), device, compile_properties);
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "whisper encoder model");
+        m_models.encoder = compiled_model.create_infer_request();
+        compiled_model =
+            core.compile_model((models_path / "openvino_decoder_model.xml").string(), device, compile_properties);
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
+        m_models.decoder = compiled_model.create_infer_request();
+        compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, compile_properties);
+        m_models.decoder_with_past = compiled_model.create_infer_request();
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");
 
         // If eos_token_id was not provided, take value
         if (m_generation_config.eos_token_id == -1) {
```

src/cpp/src/whisper_pipeline_static.cpp (+10 −3)

```diff
@@ -555,9 +555,16 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     preprocess_decoder(decoder_model);
     preprocess_decoder(decoder_with_past_model);
 
-    m_models.encoder = core.compile_model(encoder_model, "NPU").create_infer_request();
-    m_models.decoder = core.compile_model(decoder_model, "NPU").create_infer_request();
-    m_models.decoder_with_past = core.compile_model(decoder_with_past_model, "NPU").create_infer_request();
+    ov::CompiledModel compiled_model;
+    compiled_model = core.compile_model(encoder_model, "NPU");
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
+    m_models.encoder = compiled_model.create_infer_request();
+    compiled_model = core.compile_model(decoder_model, "NPU");
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
+    m_models.decoder = compiled_model.create_infer_request();
+    compiled_model = core.compile_model(decoder_with_past_model, "NPU");
+    ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
+    m_models.decoder_with_past = compiled_model.create_infer_request();
 
     // If eos_token_id was not provided, take value
     if (m_generation_config.eos_token_id == -1) {
```

src/docs/DEBUG_LOG.md (+43, new file)

````diff
@@ -0,0 +1,43 @@
+## 1. Using Debug Log
+
+There are six levels of logs, which can be called explicitly or set via the ``OPENVINO_LOG_LEVEL`` environment variable:
+
+0 - ``ov::log::Level::NO``
+1 - ``ov::log::Level::ERR``
+2 - ``ov::log::Level::WARNING``
+3 - ``ov::log::Level::INFO``
+4 - ``ov::log::Level::DEBUG``
+5 - ``ov::log::Level::TRACE``
+
+When setting the environment variable OPENVINO_LOG_LEVEL > ov::log::Level::WARNING, the properties of the compiled model can be printed.
+
+For example:
+
+Linux - export OPENVINO_LOG_LEVEL=3
+Windows - set OPENVINO_LOG_LEVEL=3
+
+the properties of the compiled model are printed as follows:
+```sh
+NETWORK_NAME: Model0
+OPTIMAL_NUMBER_OF_INFER_REQUESTS: 1
+NUM_STREAMS: 1
+INFERENCE_NUM_THREADS: 48
+PERF_COUNT: NO
+INFERENCE_PRECISION_HINT: bf16
+PERFORMANCE_HINT: LATENCY
+EXECUTION_MODE_HINT: PERFORMANCE
+PERFORMANCE_HINT_NUM_REQUESTS: 0
+ENABLE_CPU_PINNING: YES
+SCHEDULING_CORE_TYPE: ANY_CORE
+MODEL_DISTRIBUTION_POLICY:
+ENABLE_HYPER_THREADING: NO
+EXECUTION_DEVICES: CPU
+CPU_DENORMALS_OPTIMIZATION: NO
+LOG_LEVEL: LOG_NONE
+CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1
+DYNAMIC_QUANTIZATION_GROUP_SIZE: 32
+KV_CACHE_PRECISION: f16
+AFFINITY: CORE
+EXECUTION_DEVICES:
+CPU: Intel(R) Xeon(R) Platinum 8468
+```
````
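The same dump can be approximated outside of genai with the public OpenVINO runtime API alone. A rough sketch, assuming an IR file at an illustrative path (the nested `ov::device::properties` handling that the genai helper adds for multi-device plugins is omitted):

```cpp
// Rough sketch: reproduce the property dump with the public OpenVINO API only.
// "model.xml" is an illustrative path; AUTO/HETERO nested device properties
// would need the extra unpacking shown in utils.cpp above.
#include <iostream>
#include "openvino/openvino.hpp"

int main() {
    ov::Core core;
    ov::CompiledModel compiled = core.compile_model("model.xml", "CPU");
    for (const auto& name : compiled.get_property(ov::supported_properties)) {
        if (name == ov::supported_properties)
            continue;  // skip the meta-property that lists the others
        std::cout << name << ": " << compiled.get_property(name).as<std::string>() << std::endl;
    }
}
```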
