
Commit cbb1fa0

popovaan and Wovchena authored

VLM performance metrics. (openvinotoolkit#1263)

VLM performance metrics. Ticket: CVS-156661

Co-authored-by: Vladimir Zlobin <vladimir.zlobin@intel.com>

1 parent 6f160e0 commit cbb1fa0

File tree: 17 files changed, +589 -48 lines changed


.github/workflows/causal_lm_cpp.yml (+14 -1)

```diff
@@ -727,7 +727,7 @@ jobs:
           ov_link: ${{ env.l_u22_ov_link }}
       - uses: ./.github/actions/build_app
         with:
-          build_target: 'visual_language_chat py_openvino_genai'
+          build_target: 'visual_language_chat benchmark_vlm py_openvino_genai'
       - uses: ./.github/actions/install_python_deps
       - name: Download and convert tiny-random-minicpmv-2_6 model and an image
         run: |
@@ -754,6 +754,12 @@ jobs:
           && ./build/samples/cpp/visual_language_chat/visual_language_chat ./tiny-random-minicpmv-2_6/ ./images/
           <<< $'Describe the images?' | tee cpp.txt
         timeout-minutes: 2
+      - name: Run benchmark_vlm C++ sample - tiny-random-minicpmv-2_6
+        run: >
+          set -o pipefail
+          && source ./ov/setupvars.sh
+          && ./build/samples/cpp/visual_language_chat/benchmark_vlm -m ./tiny-random-minicpmv-2_6/ -i ./images/cat.png -n 3
+        timeout-minutes: 2
       - name: Run visual_language_chat Python sample - tiny-random-minicpmv-2_6
         run: >
           set -o pipefail
@@ -762,6 +768,13 @@ jobs:
           <<< $'Describe the images?' | tee py.txt
         env:
           PYTHONPATH: "./build/"
+      - name: Run benchmark_vlm Python sample - tiny-random-minicpmv-2_6
+        run: >
+          set -o pipefail
+          && source ./ov/setupvars.sh
+          && ./samples/python/visual_language_chat/benchmark_vlm.py -m ./tiny-random-minicpmv-2_6/ -i ./images/cat.png -n 3
+        env:
+          PYTHONPATH: "./build/"
       - name: Encode cpp.txt with Python encoding instead of terminal one
         shell: python
         run: |
```

samples/cpp/visual_language_chat/CMakeLists.txt (+17)

```diff
@@ -13,6 +13,8 @@ file(DOWNLOAD
     ${CMAKE_BINARY_DIR}/stb_image.h
     EXPECTED_HASH MD5=27932e6fb3a2f26aee2fc33f2cb4e696)
 
+# create main sample executable
+
 add_executable(visual_language_chat visual_language_chat.cpp load_image.cpp)
 target_include_directories(visual_language_chat PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_BINARY_DIR}")
 target_link_libraries(visual_language_chat PRIVATE openvino::genai)
@@ -26,3 +28,18 @@ install(TARGETS visual_language_chat
     RUNTIME DESTINATION samples_bin/
     COMPONENT samples_bin
     EXCLUDE_FROM_ALL)
+
+# create benchmark executable
+
+add_executable(benchmark_vlm benchmark_vlm.cpp load_image.cpp)
+target_include_directories(benchmark_vlm PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_BINARY_DIR}")
+target_link_libraries(benchmark_vlm PRIVATE openvino::genai cxxopts::cxxopts)
+set_target_properties(benchmark_vlm PROPERTIES
+    COMPILE_PDB_NAME benchmark_vlm
+    # Ensure out of box LC_RPATH on macOS with SIP
+    INSTALL_RPATH_USE_LINK_PATH ON)
+
+install(TARGETS benchmark_vlm
+    RUNTIME DESTINATION samples_bin/
+    COMPONENT samples_bin
+    EXCLUDE_FROM_ALL)
```

samples/cpp/visual_language_chat/README.md (+41)

````diff
@@ -2,6 +2,12 @@
 
 This example showcases inference of Visual language models (VLMs): [`openbmb/MiniCPM-V-2_6`](https://huggingface.co/openbmb/MiniCPM-V-2_6). The application doesn't have many configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `ov::genai::VLMPipeline` and runs the simplest deterministic greedy sampling algorithm. There is also a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/minicpm-v-multimodal-chatbot) which provides an example of a visual-language assistant.
 
+There are two sample files:
+- [`visual_language_chat.cpp`](./visual_language_chat.cpp) demonstrates basic usage of the VLM pipeline.
+- [`benchmark_vlm.cpp`](./benchmark_vlm.cpp) shows how to benchmark a VLM in OpenVINO GenAI. The sample runs warm-up iterations, generates text, and calculates various performance metrics.
+
 ## Download and convert the model and tokenizers
 
 The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
@@ -25,6 +31,41 @@ Discrete GPUs (dGPUs) usually provide better performance compared to CPUs. It is
 
 See [SUPPORTED_MODELS.md](../../../src/docs/SUPPORTED_MODELS.md#visual-language-models) for the list of supported models.
 
+## Run benchmark
+
+```sh
+benchmark_vlm [OPTIONS]
+```
+
+### Options
+
+- `-m, --model` (default: `.`): Path to the model and tokenizers base directory.
+- `-p, --prompt` (default: `What is on the image?`): The prompt used for generation.
+- `-i, --image` (default: `image.jpg`): Path to the image.
+- `-nw, --num_warmup` (default: `1`): Number of warmup iterations.
+- `-mt, --max_new_tokens` (default: `20`): Maximal number of new tokens to generate.
+- `-n, --num_iter` (default: `3`): Number of iterations.
+- `-d, --device` (default: `CPU`): Device to run the model on.
+
+### Output
+
+```
+benchmark_vlm -m miniCPM-V-2_6 -i 319483352-d5fbbd1a-d484-415c-88cb-9986625b7b11.jpg -n 3
+```
+
+```
+Load time: 1982.00 ms
+Generate time: 13820.99 ± 64.62 ms
+Tokenization time: 1.26 ± 0.09 ms
+Detokenization time: 0.33 ± 0.05 ms
+Embeddings preparation time: 5733.85 ± 26.34 ms
+TTFT: 11246.98 ± 80.55 ms
+TPOT: 135.45 ± 4.73 ms/token
+Throughput: 7.38 ± 0.26 tokens/s
+```
+
+For more information on how the performance metrics are calculated, see the [performance-metrics tutorial](../../../src/README.md#performance-metrics).
+
 ### Troubleshooting
 
 #### Unicode characters encoding error on Windows
````
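As a quick consistency check on the example output above (editor's note, not part of the commit; the precise accounting is defined in the linked performance-metrics tutorial, so treat these as approximations):

$$\text{Generate time} \approx \text{TTFT} + (N_{\text{new}} - 1)\cdot\text{TPOT}, \qquad \text{Throughput} \approx \frac{1000\ \text{ms/s}}{\text{TPOT}}$$

With the default $N_{\text{new}} = 20$ max new tokens: $11246.98 + 19 \times 135.45 \approx 13820.5$ ms, matching the reported generate time, and $1000 / 135.45 \approx 7.38$ tokens/s, matching the reported throughput.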
samples/cpp/visual_language_chat/benchmark_vlm.cpp (new file, +82)

```cpp
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <cxxopts.hpp>
#include <filesystem>
#include <iomanip>
#include <iostream>

#include "load_image.hpp"
#include <openvino/genai/visual_language/pipeline.hpp>

int main(int argc, char* argv[]) try {
    cxxopts::Options options("benchmark_vlm", "Help command");

    options.add_options()
    ("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
    ("p,prompt", "Prompt", cxxopts::value<std::string>()->default_value("What is on the image?"))
    ("i,image", "Image", cxxopts::value<std::string>()->default_value("image.jpg"))
    ("nw,num_warmup", "Number of warmup iterations", cxxopts::value<size_t>()->default_value(std::to_string(1)))
    ("n,num_iter", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(3)))
    ("mt,max_new_tokens", "Maximal number of new tokens", cxxopts::value<size_t>()->default_value(std::to_string(20)))
    ("d,device", "Device to run the model on", cxxopts::value<std::string>()->default_value("CPU"))
    ("h,help", "Print usage");

    cxxopts::ParseResult result;
    try {
        result = options.parse(argc, argv);
    } catch (const cxxopts::exceptions::exception& e) {
        std::cout << e.what() << "\n\n";
        std::cout << options.help() << std::endl;
        return EXIT_FAILURE;
    }

    if (result.count("help")) {
        std::cout << options.help() << std::endl;
        return EXIT_SUCCESS;
    }

    std::string prompt = result["prompt"].as<std::string>();
    const std::string models_path = result["model"].as<std::string>();
    const std::string image_path = result["image"].as<std::string>();
    std::string device = result["device"].as<std::string>();
    size_t num_warmup = result["num_warmup"].as<size_t>();
    size_t num_iter = result["num_iter"].as<size_t>();
    ov::Tensor image = utils::load_image(image_path);

    ov::genai::GenerationConfig config;
    config.max_new_tokens = result["max_new_tokens"].as<size_t>();

    ov::genai::VLMPipeline pipe(models_path, device);

    // Warm-up runs are excluded from the reported statistics.
    for (size_t i = 0; i < num_warmup; i++)
        pipe.generate(prompt, ov::genai::image(image), ov::genai::generation_config(config));

    // Accumulate metrics over num_iter measured runs; operator+ aggregates the per-run samples.
    auto res = pipe.generate(prompt, ov::genai::image(image), ov::genai::generation_config(config));
    auto metrics = res.perf_metrics;
    for (size_t i = 0; i < num_iter - 1; i++) {
        res = pipe.generate(prompt, ov::genai::image(image), ov::genai::generation_config(config));
        metrics = metrics + res.perf_metrics;
    }

    std::cout << std::fixed << std::setprecision(2);
    std::cout << "Load time: " << metrics.get_load_time() << " ms" << std::endl;
    std::cout << "Generate time: " << metrics.get_generate_duration().mean << " ± " << metrics.get_generate_duration().std << " ms" << std::endl;
    std::cout << "Tokenization time: " << metrics.get_tokenization_duration().mean << " ± " << metrics.get_tokenization_duration().std << " ms" << std::endl;
    std::cout << "Detokenization time: " << metrics.get_detokenization_duration().mean << " ± " << metrics.get_detokenization_duration().std << " ms" << std::endl;
    std::cout << "Embeddings preparation time: " << metrics.get_prepare_embeddings_duration().mean << " ± " << metrics.get_prepare_embeddings_duration().std << " ms" << std::endl;
    std::cout << "TTFT: " << metrics.get_ttft().mean << " ± " << metrics.get_ttft().std << " ms" << std::endl;
    std::cout << "TPOT: " << metrics.get_tpot().mean << " ± " << metrics.get_tpot().std << " ms/token" << std::endl;
    std::cout << "Throughput: " << metrics.get_throughput().mean << " ± " << metrics.get_throughput().std << " tokens/s" << std::endl;

    return 0;
} catch (const std::exception& error) {
    try {
        std::cerr << error.what() << '\n';
    } catch (const std::ios_base::failure&) {}
    return EXIT_FAILURE;
} catch (...) {
    try {
        std::cerr << "Non-exception object thrown\n";
    } catch (const std::ios_base::failure&) {}
    return EXIT_FAILURE;
}
```
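The benchmark reuses `utils::load_image` from `load_image.cpp`, which is not part of this diff. Given the `stb_image.h` download in the sample's CMakeLists.txt, a loader of this kind plausibly looks like the following sketch (a hypothetical reconstruction for illustration, not the repository's actual code):

```cpp
// Hypothetical sketch of an stb_image-based loader producing the
// {1, height, width, 3} u8 tensor that VLMPipeline consumes.
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

#include <openvino/runtime/tensor.hpp>

#include <cstring>
#include <stdexcept>
#include <string>

namespace utils {
ov::Tensor load_image(const std::string& path) {
    int width = 0, height = 0, channels = 0;
    // Request 3 channels so grayscale/RGBA inputs are converted to RGB.
    unsigned char* data = stbi_load(path.c_str(), &width, &height, &channels, 3);
    if (!data) {
        throw std::runtime_error("Failed to load the image: " + path);
    }
    ov::Tensor tensor(ov::element::u8, {1, size_t(height), size_t(width), 3});
    std::memcpy(tensor.data<unsigned char>(), data, size_t(width) * size_t(height) * 3);
    stbi_image_free(data);
    return tensor;
}
}  // namespace utils
```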

samples/python/visual_language_chat/README.md (+39)

````diff
@@ -2,6 +2,10 @@
 
 This example showcases inference of text-generation Vision Language Models (VLMs): `miniCPM-V-2_6` and other models with the same signature. The application doesn't have many configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `openvino_genai.VLMPipeline` and configures it for the chat scenario. There is also a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/minicpm-v-multimodal-chatbot) which provides an example of a visual-language assistant.
 
+There are two sample files:
+- [`visual_language_chat.py`](./visual_language_chat.py) demonstrates basic usage of the VLM pipeline.
+- [`benchmark_vlm.py`](./benchmark_vlm.py) shows how to benchmark a VLM in OpenVINO GenAI. The script runs warm-up iterations, generates text, and calculates various performance metrics.
+
 ## Download and convert the model and tokenizers
 
 The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
@@ -27,6 +31,41 @@ Modify the source code to change the device for inference to the GPU.
 
 See https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md#supported-models for the list of supported models.
 
+## Run benchmark
+
+```sh
+python benchmark_vlm.py [OPTIONS]
+```
+
+### Options
+
+- `-m, --model` (default: `.`): Path to the model and tokenizers base directory.
+- `-p, --prompt` (default: `What is on the image?`): The prompt used for generation.
+- `-i, --image` (default: `image.jpg`): Path to the image.
+- `-nw, --num_warmup` (default: `1`): Number of warmup iterations.
+- `-mt, --max_new_tokens` (default: `20`): Maximal number of new tokens to generate.
+- `-n, --num_iter` (default: `3`): Number of iterations.
+- `-d, --device` (default: `CPU`): Device to run the model on.
+
+### Output
+
+```
+python benchmark_vlm.py -m miniCPM-V-2_6 -i 319483352-d5fbbd1a-d484-415c-88cb-9986625b7b11.jpg -n 3
+```
+
+```
+Load time: 1982.00 ms
+Generate time: 13820.99 ± 64.62 ms
+Tokenization time: 1.26 ± 0.09 ms
+Detokenization time: 0.33 ± 0.05 ms
+Embeddings preparation time: 5733.85 ± 26.34 ms
+TTFT: 11246.98 ± 80.55 ms
+TPOT: 135.45 ± 4.73 ms/token
+Throughput: 7.38 ± 0.26 tokens/s
+```
+
+For more information on how the performance metrics are calculated, see the [performance-metrics tutorial](../../../src/README.md#performance-metrics).
+
 ### Troubleshooting
 
 #### Unicode characters encoding error on Windows
````
samples/python/visual_language_chat/benchmark_vlm.py (new file, +76; defaults below are aligned with the README and the C++ sample)

```python
#!/usr/bin/env python3
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse

import numpy as np
import openvino_genai as ov_genai
from openvino import Tensor
from PIL import Image


def read_image(path: str) -> Tensor:
    '''
    Args:
        path: The path to the image.

    Returns: the ov.Tensor containing the image.
    '''
    pic = Image.open(path).convert("RGB")
    image_data = np.array(pic.getdata()).reshape(1, pic.size[1], pic.size[0], 3).astype(np.uint8)
    return Tensor(image_data)


def main():
    parser = argparse.ArgumentParser(description="Help command")
    parser.add_argument("-m", "--model", type=str, default=".", help="Path to model and tokenizers base directory")
    parser.add_argument("-p", "--prompt", type=str, default="What is on the image?", help="Prompt")
    parser.add_argument("-i", "--image", type=str, default="image.jpg", help="Image")
    parser.add_argument("-nw", "--num_warmup", type=int, default=1, help="Number of warmup iterations")
    parser.add_argument("-n", "--num_iter", type=int, default=3, help="Number of iterations")
    parser.add_argument("-mt", "--max_new_tokens", type=int, default=20, help="Maximal number of new tokens")
    parser.add_argument("-d", "--device", type=str, default="CPU", help="Device")

    args = parser.parse_args()

    # Perf metrics are stored in VLMDecodedResults.
    # To get VLMDecodedResults instead of a string, the input should be a list.
    prompt = args.prompt
    models_path = args.model
    image = read_image(args.image)
    device = args.device
    num_warmup = args.num_warmup
    num_iter = args.num_iter

    config = ov_genai.GenerationConfig()
    config.max_new_tokens = args.max_new_tokens

    pipe = ov_genai.VLMPipeline(models_path, device)

    # Warm-up runs are excluded from the reported statistics.
    for _ in range(num_warmup):
        pipe.generate(prompt, images=image, generation_config=config)

    # Accumulate metrics over num_iter measured runs; += aggregates the per-run samples.
    res = pipe.generate(prompt, images=image, generation_config=config)
    perf_metrics = res.perf_metrics
    for _ in range(num_iter - 1):
        res = pipe.generate(prompt, images=image, generation_config=config)
        perf_metrics += res.perf_metrics

    print(f"Load time: {perf_metrics.get_load_time():.2f} ms")
    print(
        f"Generate time: {perf_metrics.get_generate_duration().mean:.2f} ± {perf_metrics.get_generate_duration().std:.2f} ms")
    print(
        f"Tokenization time: {perf_metrics.get_tokenization_duration().mean:.2f} ± {perf_metrics.get_tokenization_duration().std:.2f} ms")
    print(
        f"Detokenization time: {perf_metrics.get_detokenization_duration().mean:.2f} ± {perf_metrics.get_detokenization_duration().std:.2f} ms")
    print(
        f"Embeddings preparation time: {perf_metrics.get_prepare_embeddings_duration().mean:.2f} ± {perf_metrics.get_prepare_embeddings_duration().std:.2f} ms")
    print(f"TTFT: {perf_metrics.get_ttft().mean:.2f} ± {perf_metrics.get_ttft().std:.2f} ms")
    print(f"TPOT: {perf_metrics.get_tpot().mean:.2f} ± {perf_metrics.get_tpot().std:.2f} ms/token")
    print(f"Throughput: {perf_metrics.get_throughput().mean:.2f} ± {perf_metrics.get_throughput().std:.2f} tokens/s")


if __name__ == "__main__":
    main()
```
src/cpp/include/openvino/genai/visual_language/perf_metrics.hpp (new file, +34)

```cpp
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <optional>
#include <vector>

#include "openvino/genai/perf_metrics.hpp"
#include "openvino/genai/visibility.hpp"

namespace ov::genai {

struct OPENVINO_GENAI_EXPORTS VLMRawPerfMetrics {
    /** @brief Durations of embeddings preparation, one sample per generate() call */
    std::vector<MicroSeconds> prepare_embeddings_durations;
};

struct OPENVINO_GENAI_EXPORTS VLMPerfMetrics : public PerfMetrics {
    /** @brief Mean and standard deviation of embeddings preparation time, in milliseconds */
    MeanStdPair prepare_embeddings_duration;

    MeanStdPair get_prepare_embeddings_duration();

    VLMPerfMetrics() = default;

    VLMPerfMetrics(PerfMetrics& perf_metrics) : PerfMetrics(perf_metrics) {}

    void evaluate_statistics(std::optional<TimePoint> start_time = std::nullopt) override;

    VLMPerfMetrics operator+(const VLMPerfMetrics& metrics) const;

    VLMRawPerfMetrics vlm_raw_metrics;
};

}  // namespace ov::genai
```
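`evaluate_statistics` is implemented in a source file that is not part of this excerpt; conceptually it has to reduce `prepare_embeddings_durations` (raw microsecond samples) to the millisecond `MeanStdPair` that the benchmarks print. A rough sketch of that reduction, as an illustration only (the actual implementation may differ):

```cpp
// Illustration only: reduce raw microsecond duration samples to a
// mean/standard-deviation pair in milliseconds, the way the reported
// prepare_embeddings_duration pair is presumably derived.
// Assumes at least one sample is present.
#include <cmath>
#include <vector>

struct MeanStd { float mean; float std; };

MeanStd reduce_durations_us(const std::vector<float>& samples_us) {
    float mean_ms = 0.0f;
    for (float us : samples_us)
        mean_ms += us / 1000.0f;  // microseconds -> milliseconds
    mean_ms /= samples_us.size();

    float var_ms2 = 0.0f;
    for (float us : samples_us) {
        const float diff = us / 1000.0f - mean_ms;
        var_ms2 += diff * diff;
    }
    var_ms2 /= samples_us.size();
    return {mean_ms, std::sqrt(var_ms2)};
}
```

This is also why `operator+` matters: concatenating the raw samples from several runs lets the mean and deviation be computed over all iterations rather than averaged pairwise.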
