Add python bindings to VLMPipeline for encrypted models #1916

Open · wants to merge 4 commits into base: master

Changes from 1 commit
Add python bindings to VLMPipeline for encrypted models
olpipi committed Mar 17, 2025
commit 9f8a830bb301e64c4f162ade00e3364f15fcef5b
17 changes: 16 additions & 1 deletion samples/cpp/visual_language_chat/CMakeLists.txt
@@ -28,6 +28,21 @@ install(TARGETS visual_language_chat
COMPONENT samples_bin
EXCLUDE_FROM_ALL)

# create encrypted model sample executable

add_executable(encrypted_model_vlm encrypted_model_vlm.cpp load_image.cpp)
target_include_directories(encrypted_model_vlm PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_BINARY_DIR}")
target_link_libraries(encrypted_model_vlm PRIVATE openvino::genai)

set_target_properties(encrypted_model_vlm PROPERTIES
# Ensure out of box LC_RPATH on macOS with SIP
INSTALL_RPATH_USE_LINK_PATH ON)

install(TARGETS encrypted_model_vlm
RUNTIME DESTINATION samples_bin/
COMPONENT samples_bin
EXCLUDE_FROM_ALL)

# create benchmark executable

add_executable(benchmark_vlm benchmark_vlm.cpp load_image.cpp)
@@ -40,4 +55,4 @@ set_target_properties(benchmark_vlm PROPERTIES
install(TARGETS benchmark_vlm
RUNTIME DESTINATION samples_bin/
COMPONENT samples_bin
EXCLUDE_FROM_ALL)
EXCLUDE_FROM_ALL)
101 changes: 101 additions & 0 deletions samples/cpp/visual_language_chat/encrypted_model_vlm.cpp
@@ -0,0 +1,101 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <filesystem>
#include <fstream>

#include "load_image.hpp"
#include "openvino/genai/visual_language/pipeline.hpp"

std::pair<std::string, ov::Tensor> decrypt_model(const std::string& model_path, const std::string& weights_path) {
    std::ifstream model_file(model_path);
    std::ifstream weights_file(weights_path, std::ios::binary);
    if (!model_file.is_open() || !weights_file.is_open()) {
        throw std::runtime_error("Cannot open model or weights file");
    }

    // User can add file decryption of model_file and weights_file in memory here.
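    // For illustration only (an assumption, not something this sample implements):
    // if the files had been encrypted with a simple byte-wise XOR, the in-memory
    // decryption of a text buffer could look like
    //
    //     auto xor_decrypt = [](std::string& buffer, char key) {
    //         for (char& byte : buffer) byte ^= key;
    //     };
    //
    // applied to the model string (and analogously to the weights tensor data)
    // after they are read below. A real deployment would use a proper cipher
    // such as AES-GCM rather than XOR.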


    std::string model_str((std::istreambuf_iterator<char>(model_file)), std::istreambuf_iterator<char>());

    weights_file.seekg(0, std::ios::end);
    auto weight_size = static_cast<size_t>(weights_file.tellg());
    weights_file.seekg(0, std::ios::beg);
    auto weights_tensor = ov::Tensor(ov::element::u8, {weight_size});
    if (!weights_file.read(static_cast<char*>(weights_tensor.data()), weight_size)) {
        throw std::runtime_error("Cannot read weights file");
    }

    return {model_str, weights_tensor};
}

ov::genai::Tokenizer decrypt_tokenizer(const std::string& models_path) {
    std::string tok_model_path = models_path + "/openvino_tokenizer.xml";
    std::string tok_weights_path = models_path + "/openvino_tokenizer.bin";
    auto [tok_model_str, tok_weights_tensor] = decrypt_model(tok_model_path, tok_weights_path);

    std::string detok_model_path = models_path + "/openvino_detokenizer.xml";
    std::string detok_weights_path = models_path + "/openvino_detokenizer.bin";
    auto [detok_model_str, detok_weights_tensor] = decrypt_model(detok_model_path, detok_weights_path);

    return ov::genai::Tokenizer(tok_model_str, tok_weights_tensor, detok_model_str, detok_weights_tensor);
}


bool print_subword(std::string&& subword) {
    return !(std::cout << subword << std::flush);
}

int main(int argc, char* argv[]) try {
    if (4 != argc) {
        throw std::runtime_error(std::string{"Usage "} + argv[0] + " <MODEL_DIR> <IMAGE_FILE OR DIR_WITH_IMAGES> <PROMPT>");
    }

    // Read and decrypt models
    std::string models_path = argv[1];
    auto language_model = decrypt_model(models_path + "/openvino_language_model.xml", models_path + "/openvino_language_model.bin");
    auto resampler_model = decrypt_model(models_path + "/openvino_resampler_model.xml", models_path + "/openvino_resampler_model.bin");
    auto text_embeddings_model = decrypt_model(models_path + "/openvino_text_embeddings_model.xml", models_path + "/openvino_text_embeddings_model.bin");
    auto vision_embeddings_model = decrypt_model(models_path + "/openvino_vision_embeddings_model.xml", models_path + "/openvino_vision_embeddings_model.bin");

    ov::genai::ModelsMap models_map;
    models_map.emplace("language", std::move(language_model));
    models_map.emplace("resampler", std::move(resampler_model));
    models_map.emplace("text_embeddings", std::move(text_embeddings_model));
    models_map.emplace("vision_embeddings", std::move(vision_embeddings_model));
    ov::genai::Tokenizer tokenizer = decrypt_tokenizer(models_path);

    std::vector<ov::Tensor> rgbs = utils::load_images(argv[2]);

    // GPU and NPU can be used as well.
    // Note: If NPU is selected, only the language model will run on NPU.
    std::string device = "CPU";
    ov::AnyMap enable_compile_cache;
    if (device == "GPU") {
        // Cache compiled models on disk for GPU to save time on the
        // next run. It's not beneficial for CPU.
        enable_compile_cache.insert({ov::cache_dir("vlm_cache")});
    }
    ov::genai::VLMPipeline pipe(models_map, tokenizer, models_path, device, enable_compile_cache);

    ov::genai::GenerationConfig generation_config;
    generation_config.max_new_tokens = 100;

    std::string prompt = argv[3];
    pipe.generate(prompt,
                  ov::genai::images(rgbs),
                  ov::genai::generation_config(generation_config),
                  ov::genai::streamer(print_subword));

} catch (const std::exception& error) {
    try {
        std::cerr << error.what() << '\n';
    } catch (const std::ios_base::failure&) {}
    return EXIT_FAILURE;
} catch (...) {
    try {
        std::cerr << "Non-exception object thrown\n";
    } catch (const std::ios_base::failure&) {}
    return EXIT_FAILURE;
}
112 changes: 112 additions & 0 deletions samples/python/visual_language_chat/encrypted_model_vlm.py
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse

import numpy as np
import openvino_genai
import openvino
from PIL import Image
from openvino import Tensor
from pathlib import Path
import typing

def decrypt_model(model_dir, model_file_name, weights_file_name):
    with open(model_dir + '/' + model_file_name, "r") as file:
        model = file.read()
    # decrypt model

    with open(model_dir + '/' + weights_file_name, "rb") as file:
        binary_data = file.read()
    # decrypt weights
    weights = np.frombuffer(binary_data, dtype=np.uint8).astype(np.uint8)

    return model, Tensor(weights)
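
# For illustration only: the "decrypt model" / "decrypt weights" placeholders above
# could be implemented like the helper below if the files had been encrypted with
# Fernet from the `cryptography` package. This is an assumption made for the example;
# the sample itself never calls this function.
def decrypt_model_fernet(model_dir, model_file_name, weights_file_name, key: bytes):
    from cryptography.fernet import Fernet  # local import: only needed if this helper is used
    fernet = Fernet(key)
    with open(model_dir + '/' + model_file_name, "rb") as file:
        model = fernet.decrypt(file.read()).decode("utf-8")
    with open(model_dir + '/' + weights_file_name, "rb") as file:
        weights = np.frombuffer(fernet.decrypt(file.read()), dtype=np.uint8).astype(np.uint8)
    return model, Tensor(weights)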

def read_tokenizer(model_dir):
    tokenizer_model_name = 'openvino_tokenizer.xml'
    tokenizer_weights_name = 'openvino_tokenizer.bin'
    tokenizer_model, tokenizer_weights = decrypt_model(model_dir, tokenizer_model_name, tokenizer_weights_name)

    detokenizer_model_name = 'openvino_detokenizer.xml'
    detokenizer_weights_name = 'openvino_detokenizer.bin'
    detokenizer_model, detokenizer_weights = decrypt_model(model_dir, detokenizer_model_name, detokenizer_weights_name)

    return openvino_genai.Tokenizer(tokenizer_model, tokenizer_weights, detokenizer_model, detokenizer_weights)

def streamer(subword: str) -> bool:
    '''

    Args:
        subword: sub-word of the generated text.

    Returns: flag that indicates whether generation should be stopped.

    '''
    print(subword, end='', flush=True)

    # No value is returned, as in this example we don't want to stop the generation in this method.
    # "return None" will be treated the same as "return openvino_genai.StreamingStatus.RUNNING".


def read_image(path: str) -> Tensor:
    '''

    Args:
        path: The path to the image.

    Returns: the ov.Tensor containing the image.

    '''
    pic = Image.open(path).convert("RGB")
    image_data = np.array(pic)
    return Tensor(image_data)


def read_images(path: str) -> list[Tensor]:
    entry = Path(path)
    if entry.is_dir():
        return [read_image(str(file)) for file in sorted(entry.iterdir())]
    return [read_image(path)]

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('model_dir')
    parser.add_argument('image_dir', help="Image file or dir with images")
    parser.add_argument('prompt', help="Prompt for the model")
    args = parser.parse_args()

    model_name_to_file_map = {
        ('language', 'openvino_language_model'),
        ('resampler', 'openvino_resampler_model'),
        ('text_embeddings', 'openvino_text_embeddings_model'),
        ('vision_embeddings', 'openvino_vision_embeddings_model')}

    models_map = dict()
    for model_name, file_name in model_name_to_file_map:
        model, weights = decrypt_model(args.model_dir, file_name + '.xml', file_name + '.bin')
        models_map[model_name] = (model, weights)

    tokenizer = read_tokenizer(args.model_dir)

    # GPU and NPU can be used as well.
    # Note: If NPU is selected, only the language model will run on NPU.
    device = 'CPU'
    enable_compile_cache = dict()
    if "GPU" == device:
        # Cache compiled models on disk for GPU to save time on the
        # next run. It's not beneficial for CPU.
        enable_compile_cache["CACHE_DIR"] = "vlm_cache"

    pipe = openvino_genai.VLMPipeline(models_map, tokenizer, args.model_dir, device, **enable_compile_cache)

    config = openvino_genai.GenerationConfig()
    config.max_new_tokens = 100

    rgbs = read_images(args.image_dir)

    pipe.generate(args.prompt, images=rgbs, generation_config=config, streamer=streamer)

if '__main__' == __name__:
    main()
2 changes: 1 addition & 1 deletion src/cpp/src/visual_language/pipeline.cpp
@@ -122,7 +122,7 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
) :
m_generation_config{generation_config} {
m_is_npu = device.find("NPU") != std::string::npos;
OPENVINO_ASSERT(m_is_npu,
OPENVINO_ASSERT(!m_is_npu,
"VLMPipeline initialization from string isn't supported for NPU device");

m_inputs_embedder = std::make_shared<InputsEmbedder>(models_map, tokenizer, config_dir_path, device, properties);
13 changes: 12 additions & 1 deletion src/python/openvino_genai/py_openvino_genai.pyi
@@ -2092,14 +2092,25 @@ class VLMPipeline:
"""
This class is used for generation with VLMs
"""
@typing.overload
def __init__(self, models_path: os.PathLike, device: str, **kwargs) -> None:
"""
device on which inference will be done
VLMPipeline class constructor.
models_path (os.PathLike): Path to the folder with exported model files.
device (str): Device to run the model on (e.g., CPU, GPU). Default is 'CPU'.
kwargs: Device properties
"""
    @typing.overload
    def __init__(self, models: typing.Dict[str, typing.Tuple[str, openvino.Tensor]], tokenizer: Tokenizer, config_dir_path: os.PathLike, device: str, generation_config: GenerationConfig | None = None, **kwargs) -> None:
        """
        VLMPipeline class constructor.
        models (typing.Dict[str, typing.Tuple[str, openvino.Tensor]]): Map with decrypted models. It should contain the following models: language, resampler, text_embeddings, vision_embeddings.
        tokenizer (Tokenizer): GenAI Tokenizer.
        config_dir_path (os.PathLike): Path to folder with model configs.
        device (str): Device to run the model on (e.g., CPU, GPU). Default is 'CPU'.
        generation_config (GenerationConfig | None): Optional generation configuration.
        kwargs: Device properties
        """
    def finish_chat(self) -> None:
        ...
    @typing.overload
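
For quick reference, a minimal sketch of how this new overload can be called from Python. It mirrors the encrypted_model_vlm.py sample above: the plain file reads stand in for the user's own decryption step, and the model folder name is a hypothetical placeholder.

import numpy as np
import openvino
import openvino_genai

def read_pair(model_dir: str, stem: str):
    # Stand-in for decryption: read the IR pair from disk as-is.
    with open(f"{model_dir}/{stem}.xml") as f:
        model = f.read()
    with open(f"{model_dir}/{stem}.bin", "rb") as f:
        weights = openvino.Tensor(np.frombuffer(f.read(), dtype=np.uint8).copy())
    return model, weights

model_dir = "exported_vlm"  # hypothetical folder with the exported model files
models = {name: read_pair(model_dir, stem) for name, stem in (
    ("language", "openvino_language_model"),
    ("resampler", "openvino_resampler_model"),
    ("text_embeddings", "openvino_text_embeddings_model"),
    ("vision_embeddings", "openvino_vision_embeddings_model"),
)}
tokenizer = openvino_genai.Tokenizer(*read_pair(model_dir, "openvino_tokenizer"),
                                     *read_pair(model_dir, "openvino_detokenizer"))
pipe = openvino_genai.VLMPipeline(models, tokenizer, model_dir, "CPU")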
28 changes: 27 additions & 1 deletion src/python/py_vlm_pipeline.cpp
@@ -165,14 +165,40 @@ void init_vlm_pipeline(py::module_& m) {
            return std::make_unique<ov::genai::VLMPipeline>(models_path, device, pyutils::kwargs_to_any_map(kwargs));
        }),
        py::arg("models_path"), "folder with exported model files",
        py::arg("device"), "device on which inference will be done"
        py::arg("device"), "device on which inference will be done",
        R"(
            VLMPipeline class constructor.
            models_path (os.PathLike): Path to the folder with exported model files.
            device (str): Device to run the model on (e.g., CPU, GPU). Default is 'CPU'.
            kwargs: Device properties
        )")

        .def(py::init([](
            const ov::genai::ModelsMap& models,
            const ov::genai::Tokenizer& tokenizer,
            const std::filesystem::path& config_dir_path,
            const std::string& device,
            const ov::genai::OptionalGenerationConfig& generation_config,
            const py::kwargs& kwargs
        ) {
            return std::make_unique<ov::genai::VLMPipeline>(models, tokenizer, config_dir_path, device, pyutils::kwargs_to_any_map(kwargs), generation_config.value_or(ov::genai::GenerationConfig()));
        }),
        py::arg("models"), "map with decrypted models",
        py::arg("tokenizer"), "GenAI Tokenizer",
        py::arg("config_dir_path"), "Path to folder with model configs",
        py::arg("device"), "device on which inference will be done",
        py::arg("generation_config") = std::nullopt, "generation config",
        R"(
            VLMPipeline class constructor.
            models (typing.Dict[str, typing.Tuple[str, openvino.Tensor]]): Map with decrypted models. It should contain the following models: language, resampler, text_embeddings, vision_embeddings.
            tokenizer (Tokenizer): GenAI Tokenizer.
            config_dir_path (os.PathLike): Path to folder with model configs.
            device (str): Device to run the model on (e.g., CPU, GPU). Default is 'CPU'.
            generation_config (GenerationConfig | None): Optional generation configuration.
            kwargs: Device properties
        )")

.def("start_chat", &ov::genai::VLMPipeline::start_chat, py::arg("system_message") = "")
.def("finish_chat", &ov::genai::VLMPipeline::finish_chat)
.def("set_chat_template", &ov::genai::VLMPipeline::set_chat_template, py::arg("chat_template"))
56 changes: 56 additions & 0 deletions tests/python_tests/samples/test_encrypted_model_vlm.py
@@ -0,0 +1,56 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import pytest
import sys

from conftest import SAMPLES_PY_DIR, SAMPLES_CPP_DIR
from test_utils import run_sample

def generate_images(path):
    from PIL import Image
    import numpy as np
    import requests
    res = 28, 28
    lines = np.arange(res[0] * res[1] * 3, dtype=np.uint8) % 255
    lines = lines.reshape([*res, 3])
    lines_image = Image.fromarray(lines)
    cat = Image.open(requests.get("https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11", stream=True).raw).convert('RGB')

    # os.path.join accepts both str and pathlib.Path (pytest's tmp_path) inputs.
    lines_image_path = os.path.join(path, "lines.png")
    cat_path = os.path.join(path, "cat.png")
    lines_image.save(lines_image_path)
    cat.save(cat_path)
    # Return plain paths instead of yielding them: the caller does not iterate the
    # result, and pytest's tmp_path is cleaned up automatically after the test.
    return lines_image_path, cat_path

class TestEncryptedVLM:
    @pytest.mark.llm
    @pytest.mark.samples
    @pytest.mark.parametrize("convert_model", ["tiny-random-minicpmv-2_6"], indirect=True)
    @pytest.mark.parametrize("sample_args", ["Describe the images."])
    def test_sample_encrypted_lm(self, convert_model, sample_args, tmp_path):
        generate_images(tmp_path)

        # Test Python sample
        py_script = os.path.join(SAMPLES_PY_DIR, "visual_language_chat/encrypted_model_vlm.py")
        py_command = [sys.executable, py_script, convert_model, tmp_path, sample_args]
        py_result = run_sample(py_command)

        # Test CPP sample
        cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'encrypted_model_vlm')
        cpp_command = [cpp_sample, convert_model, tmp_path, sample_args]
        cpp_result = run_sample(cpp_command)

        # Test common sample
        py_common_script = os.path.join(SAMPLES_PY_DIR, "visual_language_chat/visual_language_chat.py")
        py_common_command = [sys.executable, py_common_script, convert_model, tmp_path]
        py_common_result = run_sample(py_common_command, sample_args)

        # Compare results
        assert py_result.stdout == cpp_result.stdout, "Results should match"
        assert py_result.stdout == py_common_result.stdout, "Results should match"