Extend VLM to run LM on NPU #1783

Merged: 20 commits merged on Mar 6, 2025

Commits (20), changes from all commits:
17d9269  Enable LM part of VLM to work on NPU (TolyaTalamanov, Mar 3, 2025)
cefb3af  Merge branch 'master' into at/vlm-pipeline (TolyaTalamanov, Mar 4, 2025)
6ac899d  Add test and clean up (TolyaTalamanov, Mar 4, 2025)
52b60b2  Merge branch 'at/vlm-pipeline' of https://github.com/TolyaTalamanov/o… (TolyaTalamanov, Mar 4, 2025)
5737b0f  Update src/cpp/src/visual_language/pipeline.cpp (TolyaTalamanov, Mar 4, 2025)
2ed3a76  Update src/cpp/src/visual_language/pipeline.cpp (TolyaTalamanov, Mar 4, 2025)
ec0aad3  Update src/cpp/src/visual_language/pipeline.cpp (TolyaTalamanov, Mar 4, 2025)
23c1316  Merge branch 'at/vlm-pipeline' of https://github.com/TolyaTalamanov/o… (TolyaTalamanov, Mar 4, 2025)
ada0a05  Add tests for NPU VLM (TolyaTalamanov, Mar 4, 2025)
74ce19c  Comment sample (TolyaTalamanov, Mar 4, 2025)
db537a2  Merge branch 'master' into at/vlm-pipeline (ilya-lavrenov, Mar 4, 2025)
570c96c  Change vlm test for NPU (TolyaTalamanov, Mar 5, 2025)
6168bd7  Add comment about NPU into python VLM sample (TolyaTalamanov, Mar 5, 2025)
0c8b236  Merge branch 'at/vlm-pipeline' of https://github.com/TolyaTalamanov/o… (TolyaTalamanov, Mar 5, 2025)
f4fdfef  Merge branch 'master' into at/vlm-pipeline (TolyaTalamanov, Mar 5, 2025)
700e16d  Update test_vlm_pipeline.py (TolyaTalamanov, Mar 5, 2025)
37936ae  Update test_vlm_pipeline.py (TolyaTalamanov, Mar 5, 2025)
fa12185  Update test_vlm_pipeline.py (TolyaTalamanov, Mar 5, 2025)
9cf0c5f  Merge branch 'master' into at/vlm-pipeline (TolyaTalamanov, Mar 5, 2025)
4c85465  Apply suggestions from code review (ilya-lavrenov, Mar 5, 2025)
4 changes: 3 additions & 1 deletion samples/cpp/visual_language_chat/visual_language_chat.cpp
@@ -16,7 +16,9 @@ int main(int argc, char* argv[]) try {

std::vector<ov::Tensor> rgbs = utils::load_images(argv[2]);

std::string device = "CPU"; // GPU can be used as well
// GPU and NPU can be used as well.
// Note: If NPU selected, only language model will be run on NPU
std::string device = "CPU";
ov::AnyMap enable_compile_cache;
if (device == "GPU") {
// Cache compiled models on disk for GPU to save time on the
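For reference, a minimal sketch (not part of this diff) of driving the updated sample with the language model offloaded to NPU. The model directory, the image folder, and the sample-local load_image.hpp helper are placeholders; as the comment in the hunk above notes, only the language model runs on NPU while the rest of the pipeline stays on CPU.

```cpp
#include <iostream>
#include <vector>

#include "load_image.hpp"  // sample-local helper providing utils::load_images()
#include "openvino/genai/visual_language/pipeline.hpp"

int main() {
    // Placeholder paths: an exported OpenVINO VLM and a directory with images.
    std::vector<ov::Tensor> rgbs = utils::load_images("./images");

    // Selecting "NPU" offloads only the language model to NPU; the vision
    // encoder and embedding model are still compiled for CPU.
    ov::genai::VLMPipeline pipe("./MiniCPM-V-2_6-ov", "NPU", ov::AnyMap{});

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    auto result = pipe.generate("Describe the images.",
                                ov::genai::images(rgbs),
                                ov::genai::generation_config(config));
    std::cout << result.texts[0] << '\n';
    return 0;
}
```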
4 changes: 3 additions & 1 deletion samples/python/visual_language_chat/visual_language_chat.py
@@ -55,7 +55,9 @@ def main():

rgbs = read_images(args.image_dir)

device = 'CPU' # GPU can be used as well
# GPU and NPU can be used as well.
# Note: If NPU selected, only language model will be run on NPU
device = 'CPU'
enable_compile_cache = dict()
if "GPU" == device:
# Cache compiled models on disk for GPU to save time on the
1 change: 0 additions & 1 deletion src/cpp/src/llm_pipeline.cpp
@@ -118,7 +118,6 @@ ov::genai::LLMPipeline::LLMPipeline(
const std::string& device,
const ov::AnyMap& user_properties) {
auto start_time = std::chrono::steady_clock::now();

auto [properties, attention_backend] = extract_attention_backend(user_properties);

// If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues
2 changes: 1 addition & 1 deletion src/cpp/src/llm_pipeline_stateful.cpp
@@ -71,7 +71,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
if (m_is_npu) {
utils::KVDesc kv_desc;
std::tie(compiled_model, kv_desc) = utils::compile_decoder_for_npu(
model, *filtered_properties, kv_pos, models_path
model, *filtered_properties, kv_pos, models_path / "openvino_model.xml"
);
m_max_kv_cache_size = kv_desc.max_prompt_len + kv_desc.min_response_len;
} else {
4 changes: 3 additions & 1 deletion src/cpp/src/llm_pipeline_static.cpp
@@ -116,7 +116,9 @@ StatefulLLMPipeline::StatefulLLMPipeline(
) : LLMPipelineImplBase(tokenizer, generation_config),
m_sampler(m_tokenizer) {
auto kv_pos = ov::genai::utils::get_kv_axes_pos(model);
auto [compiled, kv_desc] = utils::compile_decoder_for_npu(model, properties, kv_pos, models_path);
auto [compiled, kv_desc] = utils::compile_decoder_for_npu(
model, properties, kv_pos, models_path / "openvino_model.xml"
);
m_max_prompt_len = kv_desc.max_prompt_len;
m_kvcache_total = kv_desc.max_prompt_len + kv_desc.min_response_len;
m_request = compiled.create_infer_request();
2 changes: 1 addition & 1 deletion src/cpp/src/utils.cpp
@@ -492,7 +492,7 @@ compile_decoder_for_npu(const std::shared_ptr<ov::Model>& model,
properties[ov::cache_mode.name()] = CacheMode::OPTIMIZE_SPEED;
compiled = ov::genai::utils::singleton_core().compile_model(model, "NPU", properties);
} else {
compiled = ov::genai::utils::singleton_core().compile_model(model_path / "openvino_model.xml", "NPU", properties);
compiled = ov::genai::utils::singleton_core().compile_model(model_path, "NPU", properties);
}
// Also export compiled model if required
if (export_blob) {
85 changes: 65 additions & 20 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -44,6 +44,8 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
size_t m_kv_cache_seq_length_axis = 2;
// Component for applying sampling to lm outputs
Sampler m_sampler;
size_t m_max_kv_cache_size = std::numeric_limits<size_t>::max();
bool m_is_npu = false;
public:
VLMPipelineImpl(
const std::filesystem::path& models_dir,
@@ -55,22 +57,52 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
models_dir, "generation_config.json"
)
} {
m_inputs_embedder = std::make_shared<InputsEmbedder>(models_dir, device, properties);

m_tokenizer = m_inputs_embedder->get_tokenizer();
m_embedding = m_inputs_embedder->get_embedding_model();

auto compiled_language_model = utils::singleton_core().compile_model(
models_dir / "openvino_language_model.xml", device, properties
m_is_npu = device.find("NPU") != std::string::npos;

auto properties_copy = properties;
auto language_model_path = models_dir / "openvino_language_model.xml";
auto language_model = utils::singleton_core().read_model(language_model_path, {}, properties_copy);
auto kv_pos = ov::genai::utils::get_kv_axes_pos(language_model);
m_kv_cache_seq_length_axis = kv_pos.seq_len;

// In case user provided properties per-device
// {
// ov::device::properties("NPU", ...),
// ov::device::properties("CPU", ...)
// }
auto device_propertes = utils::pop_or_default<ov::AnyMap>(
properties_copy, ov::device::properties.name(), { }
);
utils::print_compiled_model_properties(compiled_language_model, "VLM language model");
auto language_model = compiled_language_model.get_runtime_model();
m_kv_cache_seq_length_axis = utils::get_kv_axes_pos(language_model).seq_len;
// Otherwise, the same properties are used for all models and devices
auto lm_properties = device_propertes.empty()
? properties_copy
: utils::pop_or_default<ov::AnyMap>(device_propertes, device, {});

ov::CompiledModel compiled_language_model;
auto embedder_device = device;
if (m_is_npu) {
embedder_device = "CPU";
utils::KVDesc kv_desc;
std::tie(compiled_language_model, kv_desc) = utils::compile_decoder_for_npu(
language_model, lm_properties, kv_pos, language_model_path
);
m_max_kv_cache_size = kv_desc.max_prompt_len + kv_desc.min_response_len;
} else {
compiled_language_model = utils::singleton_core().compile_model(language_model, device, lm_properties);
}
ov::genai::utils::print_compiled_model_properties(compiled_language_model, "VLM language model");

m_language = compiled_language_model.create_infer_request();

m_kv_cache_seq_length_axis = utils::get_kv_axes_pos(language_model).seq_len;
m_language.get_tensor("attention_mask").set_shape({1, 0});

auto embedder_properties = device_propertes.empty()
? properties_copy
: utils::pop_or_default<ov::AnyMap>(device_propertes, embedder_device, {});
m_inputs_embedder = std::make_shared<InputsEmbedder>(models_dir, embedder_device, embedder_properties);
m_tokenizer = m_inputs_embedder->get_tokenizer();
m_embedding = m_inputs_embedder->get_embedding_model();

// If eos_token_id was not provided, take value
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
@@ -80,7 +112,7 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
m_sampler.set_seed(m_generation_config.rng_seed);
}


VLMPipelineImpl(
const ModelsMap& models_map,
const Tokenizer& tokenizer,
@@ -90,6 +122,10 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
const GenerationConfig& generation_config
) :
m_generation_config{generation_config} {
m_is_npu = device.find("NPU") != std::string::npos;
OPENVINO_ASSERT(m_is_npu,
"VLMPipeline initialization from string isn't supported for NPU device");

m_inputs_embedder = std::make_shared<InputsEmbedder>(models_map, tokenizer, config_dir_path, device, properties);

m_tokenizer = m_inputs_embedder->get_tokenizer();
Expand Down Expand Up @@ -136,6 +172,14 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
generation_config.set_eos_token_id(m_generation_config.eos_token_id);
generation_config.validate();

if (m_is_npu) {
OPENVINO_ASSERT(rgbs.size() == 1u, "Currently only batch size equal to 1 is supported for NPU device!");
OPENVINO_ASSERT(generation_config.is_greedy_decoding() || generation_config.is_multinomial(),
"Currently only greedy and multinomial decoding are supported for NPU device!");
OPENVINO_ASSERT(generation_config.num_return_sequences == 1u,
"Currently only \"num_return_sequences\" equal to 1 is supported for NPU device!");
}

m_inputs_embedder->set_apply_chat_template_status(generation_config.apply_chat_template);

auto start_get_inputs_embeds = std::chrono::steady_clock::now();
@@ -179,9 +223,8 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
m_sampler.set_seed(generation_config.rng_seed);
}

utils::GenerationFinishInfo finish_info = get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests,
position_ids, kv_cache_state, m_embedding, rope_delta);

ov::genai::utils::GenerationFinishInfo finish_info = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests,
position_ids, kv_cache_state, m_embedding, rope_delta, m_max_kv_cache_size);
EncodedResults& encoded_result = finish_info.results;

auto decode_start_time = std::chrono::steady_clock::now();
@@ -208,7 +251,7 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
res_raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(generate_end_time - generate_start_time));
res_raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_end_time - decode_start_time));
res_raw_counters.tokenization_durations.insert(res_raw_counters.tokenization_durations.end(), raw_counters.tokenization_durations.begin(), raw_counters.tokenization_durations.end());

// VLM specific perf metrics
decoded.perf_metrics.vlm_raw_metrics.prepare_embeddings_durations.emplace_back(PerfMetrics::get_microsec(end_get_inputs_embeds - start_get_inputs_embeds));

@@ -220,6 +263,7 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
}

void start_chat(const std::string& system_message) override {
OPENVINO_ASSERT(!m_is_npu, "start_chat() isn't supported in VLMPipeline for NPU device");
m_is_chat_conversation = true;
bool have_state = 0 != m_language.get_tensor("attention_mask").get_size();
if (have_state) {
@@ -232,6 +276,7 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
}

void finish_chat() override {
OPENVINO_ASSERT(!m_is_npu, "finish_chat() isn't supported in VLMPipeline for NPU device");
m_is_chat_conversation = false;
// Resetting state may be slow.
m_language.reset_state();
@@ -276,8 +321,8 @@ VLMPipeline::VLMPipeline(
) {
auto start_time = std::chrono::steady_clock::now();

if (properties.find(scheduler_config.name()) != properties.end() ||
properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end() ||
if (properties.find(scheduler_config.name()) != properties.end() ||
properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end() ||
properties.find(prompt_lookup.name()) != properties.end()) {
auto [plugin_config, scheduler_config] = utils::extract_scheduler_config(properties);
m_pimpl = std::make_unique<VLMContinuousBatchingAdapter>(models_dir, scheduler_config, device, plugin_config);
@@ -298,8 +343,8 @@ VLMPipeline::VLMPipeline(
const GenerationConfig& generation_config
) {
auto start_time = std::chrono::steady_clock::now();
if (properties.find(scheduler_config.name()) != properties.end() ||
properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end() ||
if (properties.find(scheduler_config.name()) != properties.end() ||
properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end() ||
properties.find(prompt_lookup.name()) != properties.end()) {
auto [plugin_config, scheduler_config] = utils::extract_scheduler_config(properties);
m_pimpl = std::make_unique<VLMContinuousBatchingAdapter>(models_map, tokenizer, config_dir_path, scheduler_config, device, plugin_config, generation_config);
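The constructor change above also routes device-scoped properties: a caller can pass one bundle for the NPU-compiled language model and another for the CPU-hosted embedder and vision encoder. Below is a hedged sketch of that calling pattern, assuming the nested DEVICE_PROPERTIES layout read by the new code and used by the Python test further down; the individual property keys and the model path are illustrative, not mandated by this PR.

```cpp
#include <iostream>

#include "openvino/genai/visual_language/pipeline.hpp"
#include "openvino/openvino.hpp"

int main() {
    // Per-device bundles: the "NPU" map reaches the language model compiled via
    // compile_decoder_for_npu(), the "CPU" map reaches the embedder and vision
    // encoder. Keys shown here are illustrative.
    ov::AnyMap npu_properties = {{"NPUW_DEVICES", "CPU"}};
    ov::AnyMap cpu_properties = {{"PERFORMANCE_HINT", "LATENCY"}};

    ov::AnyMap properties = {
        {ov::device::properties.name(),  // "DEVICE_PROPERTIES"
         ov::AnyMap{{"NPU", npu_properties}, {"CPU", cpu_properties}}}
    };

    // With device == "NPU", the LM runs on NPU and InputsEmbedder falls back to CPU.
    ov::genai::VLMPipeline pipe("./MiniCPM-V-2_6-ov", "NPU", properties);

    std::cout << "VLM pipeline created with per-device properties\n";
    return 0;
}
```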
32 changes: 31 additions & 1 deletion tests/python_tests/test_vlm_pipeline.py
@@ -4,6 +4,8 @@
import openvino_tokenizers
import openvino
import pytest
import platform
import sys
import transformers
from optimum.intel.openvino import OVModelForVisualCausalLM
from openvino_genai import VLMPipeline, GenerationConfig, SchedulerConfig, ContinuousBatchingPipeline, GenerationStatus
@@ -92,7 +94,7 @@ def streamer(word: str) -> bool:
images = []
for link in links:
images.append(get_image_by_link(link))

result_from_streamer = []
res = ov_pipe.generate(prompts[0], images=images, generation_config=generation_config, streamer=streamer)
assert res.texts[0] == ''.join(result_from_streamer)
@@ -328,3 +330,31 @@ def test_perf_metrics(cache):
mean_dur, std_dur = perf_metrics.get_prepare_embeddings_duration()
assert np.allclose(mean_dur, np.mean(raw_dur))
assert np.allclose(std_dur, np.std(raw_dur))


@pytest.mark.precommit
@pytest.mark.nightly
# FIXME: katuni4ka/tiny-random-qwen2vl - fails on NPU
@pytest.mark.parametrize("model_id", model_ids[:-1])
@pytest.mark.skipif(
sys.platform == "darwin" or platform.machine() in ["aarch64", "arm64", "ARM64"],
reason="NPU plugin is available only on Linux and Windows x86_64",
)
def test_vlm_npu_no_exception(model_id, cache):
models_path = get_ov_model(model_ids[0], cache)
properties = {
"DEVICE_PROPERTIES":
{
"NPU": { "NPUW_DEVICES": "CPU", "NPUW_ONLINE_PIPELINE": "NONE" }
}
}

ov_pipe = VLMPipeline(models_path, "NPU", config=properties)

generation_config = ov_pipe.get_generation_config()
generation_config.max_new_tokens = 30
generation_config.set_eos_token_id(ov_pipe.get_tokenizer().get_eos_token_id())

for link in image_links_for_testing[2]:
image = get_image_by_link(link)
out = ov_pipe.generate(prompts[0], images=[image], generation_config=generation_config)