From 5eb59ea0e4a839f1306fd03c00b37eb76a413b34 Mon Sep 17 00:00:00 2001 From: Pavel Esir <pavel.esir@intel.com> Date: Thu, 16 May 2024 21:07:05 +0200 Subject: [PATCH 01/40] add tests --- .github/workflows/causal_lm_cpp.yml | 4 +- .github/workflows/genai_python_lib.yml | 12 +- .../cpp/generate_pipeline => src}/README.md | 92 ++++++++------- .../include/openvino/genai/llm_pipeline.hpp | 25 +++- src/cpp/include/openvino/genai/tokenizer.hpp | 2 +- src/cpp/src/generation_config.cpp | 99 ++++++++-------- src/cpp/src/llm_pipeline.cpp | 48 +++++--- src/cpp/src/tokenizer.cpp | 15 ++- src/cpp/src/utils.cpp | 14 ++- src/cpp/src/utils.hpp | 39 ++++++- src/python/openvino_genai/__init__.py | 4 + src/python/py_generate_pipeline.cpp | 21 +++- src/tests/python_tests/test_greedy.py | 29 ----- tests/python_tests/list_test_models.py | 23 ++++ tests/python_tests/requirements.txt | 3 + tests/python_tests/test_generate_api.py | 110 ++++++++++++++++++ text_generation/causal_lm/cpp/CMakeLists.txt | 8 +- .../{generate_pipeline => }/chat_sample.cpp | 0 .../cpp/generate_pipeline/generate_sample.cpp | 94 --------------- .../causal_lm/cpp/greedy_causal_lm.cpp | 1 + 20 files changed, 387 insertions(+), 256 deletions(-) rename {text_generation/causal_lm/cpp/generate_pipeline => src}/README.md (69%) delete mode 100644 src/tests/python_tests/test_greedy.py create mode 100644 tests/python_tests/list_test_models.py create mode 100644 tests/python_tests/requirements.txt create mode 100644 tests/python_tests/test_generate_api.py rename text_generation/causal_lm/cpp/{generate_pipeline => }/chat_sample.cpp (100%) delete mode 100644 text_generation/causal_lm/cpp/generate_pipeline/generate_sample.cpp diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml index 23d9006d07..a07dacac30 100644 --- a/.github/workflows/causal_lm_cpp.yml +++ b/.github/workflows/causal_lm_cpp.yml @@ -194,8 +194,8 @@ jobs: shell: cmd run: | call w_openvino_toolkit_windows_2024.1.0.15008.f4afc983258_x86_64\setupvars.bat - - .\build\Release\beam_search_causal_lm.exe .\TinyLlama-1.1B-Chat-v1.0\ "69" > .\pred.txt + .\build\text_generation\causal_lm\cpp\Release\beam_search_causal_lm.exe .\TinyLlama-1.1B-Chat-v1.0\ "69" > .\pred.txt + echo import transformers > ref.py echo predictions = open('pred.txt', 'r').read() >> ref.py echo tokenizer = transformers.LlamaTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0') >> ref.py diff --git a/.github/workflows/genai_python_lib.yml b/.github/workflows/genai_python_lib.yml index e9cfefff31..6697ba934a 100644 --- a/.github/workflows/genai_python_lib.yml +++ b/.github/workflows/genai_python_lib.yml @@ -18,7 +18,17 @@ jobs: - run: python -m pip install --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly # Can't load CentOS libraries from the archive - run: PYTHONPATH=./src/python/ python -c "from openvino_genai.py_generate_pipeline import LLMPipeline" - run: source ./ov/setupvars.sh && python -m pip install --pre . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly - - run: python -c "from openvino_genai.py_generate_pipeline import LLMPipeline" + - run: python -c "from openvino_genai import LLMPipeline" + - name: Install optimum-cli and run for each model + run: | + cd ./tests/ + python -m pip install -r requirements.txt + models=$(python3 generate_models.py) + echo "$models" | while read -r model_name model_path; do + echo "Processing model: $model_name at $model_path" + optimum-cli export openvino --trust-remote-code --weight-format fp16 --model "$model_name" "$model_path" + done + python -m pytest test_generate_api.py windows_genai_python_lib: runs-on: windows-latest diff --git a/text_generation/causal_lm/cpp/generate_pipeline/README.md b/src/README.md similarity index 69% rename from text_generation/causal_lm/cpp/generate_pipeline/README.md rename to src/README.md index 0a0f6010e6..ad21250989 100644 --- a/text_generation/causal_lm/cpp/generate_pipeline/README.md +++ b/src/README.md @@ -2,7 +2,7 @@ ## Usage -Firs of all you need to convert your model with optimum-cli +First of all you need to convert your model with optimum-cli ``` sh optimum-cli export openvino --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" --weight-format fp16 --trust-remote-code "TinyLlama-1.1B-Chat-v1.0" pip install openvino-genai @@ -10,19 +10,33 @@ pip install openvino-genai LLMPipeline is the main object used for decoding. You can initiliza it straigh away from the folder with the converted model. It will automanically load the main model, tokenizer, detokenizer and default generation configuration. -### In Python +### Python A minimalist example: ```python -import py_generate_pipeline as genai # set more friendly module name -pipe = genai.LLMPipeline(model_path, "CPU") +import openvino_genai as ov_genai +pipe = ov_genai.LLMPipeline(model_path, "CPU") print(pipe.generate("The Sun is yellow bacause")) ``` +Calling generate with custom generation config parameters, e.g. config for grouped beam search +```python +import openvino_genai as ov_genai +pipe = ov_genai.LLMPipeline(model_path, "CPU") + +res = pipe.generate("The Sun is yellow bacause", max_new_tokens=30, num_groups=3, group_size=5) +print(res) +``` + +output: +``` +'it is made up of carbon atoms. The carbon atoms are arranged in a linear pattern, which gives the yellow color. The arrangement of carbon atoms in' +``` + A simples chat in python: ```python import openvino_genai as ov_genai -pipe = ov_genai.LLMPipeline(model_path) +pipe = ov_ov_genai.LLMPipeline(model_path) config = {'num_groups': 3, 'group_size': 5, 'diversity_penalty': 1.1} pipe.set_generation_cofnig(config) @@ -39,60 +53,45 @@ pipe.finish_chat() ``` Test to compare with Huggingface outputs -```python -from transformers import AutoTokenizer, AutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") -model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") - -max_new_tokens = 32 -prompt = 'table is made of' - -encoded_prompt = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False) -hf_encoded_output = model.generate(encoded_prompt, max_new_tokens=max_new_tokens, do_sample=False) -hf_output = tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:]) -print(f'hf_output: {hf_output}') - -import sys -sys.path.append('build-Debug/') -import py_generate_pipeline as genai # set more friendly module name - -pipe = genai.LLMPipeline('text_generation/causal_lm/TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/') -ov_output = pipe(prompt, max_new_tokens=max_new_tokens) -print(f'ov_output: {ov_output}') -assert hf_output == ov_output - -``` - -### In C++ +### C++ Minimalistc example ```cpp +#include "openvino/genai/llm_pipeline.hpp" +#include <iostream> + int main(int argc, char* argv[]) { std::string model_path = argv[1]; ov::LLMPipeline pipe(model_path, "CPU"); - cout << pipe.generate("The Sun is yellow bacause"); + std::cout << pipe.generate("The Sun is yellow bacause"); } ``` Using Group Beam Search Decoding ```cpp +#include "openvino/genai/llm_pipeline.hpp" +#include <iostream> + int main(int argc, char* argv[]) { std::string model_path = argv[1]; ov::LLMPipeline pipe(model_path, "CPU"); + ov::GenerationConfig config = pipe.get_generation_config(); config.max_new_tokens = 256; config.num_groups = 3; config.group_size = 5; config.diversity_penalty = 1.0f; - cout << pipe.generate("The Sun is yellow bacause", config); + std::cout << pipe.generate("The Sun is yellow bacause", config); } ``` A simplest chat in C++ ``` cpp +#include "openvino/genai/llm_pipeline.hpp" +#include <iostream> + int main(int argc, char* argv[]) { std::string prompt; @@ -142,24 +141,38 @@ int main(int argc, char* argv[]) { Streaming exapmle with lambda function ``` cpp -int main(int argc, char* argv[]) { - auto streamer = [](std::string word) { std::cout << word << std::flush; }; +#include "openvino/genai/llm_pipeline.hpp" +#include <iostream> + +int main(int argc, char* argv[]) { std::string model_path = argv[1]; ov::LLMPipeline pipe(model_path, "CPU"); - cout << pipe.generate("The Sun is yellow bacause", streamer); + + auto streamer = [](std::string word) { std::cout << word << std::flush; }; + std::cout << pipe.generate("The Sun is yellow bacause", streamer); } ``` Streaming with custom class ``` cpp #include <streamer_base.hpp> +#include "openvino/genai/llm_pipeline.hpp" +#include <iostream> class CustomStreamer: publict StreamerBase { public: - void put(int64_t token) {/* decode tokens and do process them*/}; - - void end() {/* decode tokens and do process them*/}; + void put(int64_t token) { + /* custom decoding/tokens processing code + tokens_cache.push_back(token); + std::string text = m_tokenizer.decode(tokens_cache); + ... + */ + }; + + void end() { + /* custom finalization */ + }; }; int main(int argc, char* argv[]) { @@ -170,4 +183,3 @@ int main(int argc, char* argv[]) { cout << pipe.generate("The Sun is yellow bacause", custom_streamer); } ``` - diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index b25d11ecd4..2a6e53eea6 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -39,6 +39,24 @@ class DecodedResults { public: std::vector<std::string> texts; std::vector<float> scores; + + // @brief Convert DecodedResults to a vector of strings. + // @return A std::vector<std::string> containing the texts from the DecodedResults object. + operator std::vector<std::string>() const { + return texts; + } + + // @brief Overloads operator<< to enhance output the contents of DecodedResults. + // @return A reference to the output stream with the concatenated texts. + friend std::ostream& operator<<(std::ostream& os, const DecodedResults& dr) { + for (size_t i = 0; i < dr.texts.size(); ++i) { + os << dr.texts[i]; + if (i != dr.texts.size() - 1) { + os << std::endl; + } + } + return os; + } }; /** @@ -53,7 +71,9 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline { * @param device optional device * @param plugin_config optional plugin_config */ - LLMPipeline(std::string& path, std::string device="CPU", const ov::AnyMap& plugin_config={}); + LLMPipeline(std::string& path, std::string device="CPU", + const ov::AnyMap& plugin_config={}, + const std::string& ov_tokenizer_path=""); /** * @brief Constructs a LLMPipeline when ov::Tokenizer is initialized manually using file from the different dirs. @@ -67,7 +87,8 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline { const std::string model_path, const ov::Tokenizer& tokenizer, const std::string device="CPU", - const ov::AnyMap& plugin_config = {} + const ov::AnyMap& plugin_config = {}, + const std::string& ov_tokenizer_path="" ); ~LLMPipeline(); diff --git a/src/cpp/include/openvino/genai/tokenizer.hpp b/src/cpp/include/openvino/genai/tokenizer.hpp index 0d55d9b0fe..07bfe96d44 100644 --- a/src/cpp/include/openvino/genai/tokenizer.hpp +++ b/src/cpp/include/openvino/genai/tokenizer.hpp @@ -21,7 +21,7 @@ class OPENVINO_GENAI_EXPORTS Tokenizer { * @param tokenizer_path openvino_tokenizer.xml and openvino_detokenizer.xml should be located in the tokenizer_path * @param device device. Currently only 'CPU' is supported */ - Tokenizer(const std::string& tokenizers_path, const std::string& device="CPU"); + Tokenizer(const std::string& tokenizers_path, const std::string& device="CPU", const std::string& ov_tokenizer_path=""); /** * @brief encode a single prompt diff --git a/src/cpp/src/generation_config.cpp b/src/cpp/src/generation_config.cpp index b392e44b3b..f7cbcaa075 100644 --- a/src/cpp/src/generation_config.cpp +++ b/src/cpp/src/generation_config.cpp @@ -10,40 +10,44 @@ #include "openvino/genai/generation_config.hpp" #include "generation_config_helper.hpp" +#include "utils.hpp" + +namespace { + + +} // namespace + namespace ov { GenerationConfig::GenerationConfig(std::string json_path) { + using ov::generate_utils::read_json_param; + std::ifstream f(json_path); OPENVINO_ASSERT(f.is_open(), "Failed to open '" + json_path + "' with generation config"); nlohmann::json data = nlohmann::json::parse(f); - - if (data.contains("max_new_tokens")) max_new_tokens = data["max_new_tokens"]; - if (data.contains("max_length")) max_length = data["max_length"]; + + read_json_param(data, "max_new_tokens", max_new_tokens); + read_json_param(data, "max_length", max_length); // note that ignore_eos is not present in HF GenerationConfig - if (data.contains("num_beam_groups")) num_beam_groups = data["num_beam_groups"]; - if (data.contains("num_beams")) num_beams = data["num_beams"]; - if (data.contains("diversity_penalty")) diversity_penalty = data["diversity_penalty"]; - if (data.contains("length_penalty")) length_penalty = data["length_penalty"]; - if (data.contains("num_return_sequences")) num_return_sequences = data["num_return_sequences"]; - if (data.contains("no_repeat_ngram_size")) no_repeat_ngram_size = data["no_repeat_ngram_size"]; + read_json_param(data, "num_beam_groups", num_beam_groups); + read_json_param(data, "num_beams", num_beams); + read_json_param(data, "diversity_penalty", diversity_penalty); + read_json_param(data, "length_penalty", length_penalty); + read_json_param(data, "num_return_sequences", num_return_sequences); + read_json_param(data, "no_repeat_ngram_size", no_repeat_ngram_size); // stop_criteria will be processed below - if (data.contains("temperature")) temperature = data["temperature"]; - if (data.contains("top_p")) top_p = data["top_p"]; - if (data.contains("top_k")) top_k = data["top_k"]; - if (data.contains("do_sample")) do_sample = data["do_sample"]; - if (data.contains("repetition_penalty")) repetition_penalty = data["repetition_penalty"]; - if (data.contains("pad_token_id")) pad_token_id = data["pad_token_id"]; - if (data.contains("bos_token_id")) bos_token_id = data["bos_token_id"]; - - if (data.contains("eos_token_id") && data["eos_token_id"].type() == nlohmann::json::value_t::number_integer) { - // todo: qwen contains several eos_token_id - eos_token_id = data["eos_token_id"]; - } - - if (data.contains("bos_token")) bos_token = data["bos_token"]; - if (data.contains("eos_token")) eos_token = data["eos_token"]; + read_json_param(data, "temperature", temperature); + read_json_param(data, "top_p", top_p); + read_json_param(data, "top_k", top_k); + read_json_param(data, "do_sample", do_sample); + read_json_param(data, "repetition_penalty", repetition_penalty); + read_json_param(data, "pad_token_id", pad_token_id); + read_json_param(data, "bos_token_id", bos_token_id); + read_json_param(data, "eos_token_id", eos_token_id); + read_json_param(data, "bos_token", bos_token); + read_json_param(data, "eos_token", eos_token); if (data.contains("early_stopping")) { auto field_type = data["early_stopping"].type(); @@ -55,32 +59,35 @@ GenerationConfig::GenerationConfig(std::string json_path) { stop_criteria = StopCriteria::heuristic; } } + + } GenerationConfig GenerationConfigHelper::anymap_to_generation_config(const ov::AnyMap& config_map) { + using ov::generate_utils::read_anymap_param; + GenerationConfig config = m_config; - - if (config_map.count("max_new_tokens")) config.max_new_tokens = config_map.at("max_new_tokens").as<size_t>(); - if (config_map.count("max_length")) config.max_length = config_map.at("max_length").as<size_t>(); - if (config_map.count("ignore_eos")) config.ignore_eos = config_map.at("ignore_eos").as<bool>(); - if (config_map.count("num_beam_groups")) config.num_beam_groups = config_map.at("num_beam_groups").as<size_t>(); - if (config_map.count("num_beams")) config.num_beams = config_map.at("num_beams").as<size_t>(); - if (config_map.count("diversity_penalty")) config.diversity_penalty = config_map.at("diversity_penalty").as<float>(); - if (config_map.count("length_penalty")) config.length_penalty = config_map.at("length_penalty").as<float>(); - if (config_map.count("num_return_sequences")) config.num_return_sequences = config_map.at("num_return_sequences").as<size_t>(); - if (config_map.count("no_repeat_ngram_size")) config.no_repeat_ngram_size = config_map.at("no_repeat_ngram_size").as<size_t>(); - if (config_map.count("stop_criteria")) config.stop_criteria = config_map.at("stop_criteria").as<StopCriteria>(); - if (config_map.count("temperature")) config.temperature = config_map.at("temperature").as<float>(); - if (config_map.count("top_p")) config.top_p = config_map.at("top_p").as<float>(); - if (config_map.count("top_k")) config.top_k = config_map.at("top_k").as<int>(); - if (config_map.count("do_sample")) config.do_sample = config_map.at("do_sample").as<bool>(); - if (config_map.count("repetition_penalty")) config.repetition_penalty = config_map.at("repetition_penalty").as<float>(); - if (config_map.count("pad_token_id")) config.pad_token_id = config_map.at("pad_token_id").as<int64_t>(); - if (config_map.count("bos_token_id")) config.bos_token_id = config_map.at("bos_token_id").as<int64_t>(); - if (config_map.count("eos_token_id")) config.eos_token_id = config_map.at("eos_token_id").as<int64_t>(); - if (config_map.count("bos_token")) config.bos_token = config_map.at("bos_token").as<std::string>(); - if (config_map.count("eos_token")) config.eos_token = config_map.at("eos_token").as<std::string>(); - + read_anymap_param(config_map, "max_new_tokens", config.max_new_tokens); + read_anymap_param(config_map, "max_length", config.max_length); + read_anymap_param(config_map, "ignore_eos", config.ignore_eos); + read_anymap_param(config_map, "num_beam_groups", config.num_beam_groups); + read_anymap_param(config_map, "num_beams", config.num_beams); + read_anymap_param(config_map, "diversity_penalty", config.diversity_penalty); + read_anymap_param(config_map, "length_penalty", config.length_penalty); + read_anymap_param(config_map, "num_return_sequences", config.num_return_sequences); + read_anymap_param(config_map, "no_repeat_ngram_size", config.no_repeat_ngram_size); + read_anymap_param(config_map, "stop_criteria", config.stop_criteria); + read_anymap_param(config_map, "temperature", config.temperature); + read_anymap_param(config_map, "top_p", config.top_p); + read_anymap_param(config_map, "top_k", config.top_k); + read_anymap_param(config_map, "do_sample", config.do_sample); + read_anymap_param(config_map, "repetition_penalty", config.repetition_penalty); + read_anymap_param(config_map, "pad_token_id", config.pad_token_id); + read_anymap_param(config_map, "bos_token_id", config.bos_token_id); + read_anymap_param(config_map, "eos_token_id", config.eos_token_id); + read_anymap_param(config_map, "bos_token", config.bos_token); + read_anymap_param(config_map, "eos_token", config.eos_token); + return config; } diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 9d4161f859..c47e54dd48 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -44,10 +44,12 @@ class LLMPipeline::LLMPipelineImpl { const std::string model_path, const ov::Tokenizer& tokenizer, const std::string device, - const ov::AnyMap& plugin_config + const ov::AnyMap& plugin_config, + const std::string& ov_tokenizer_path="" ); LLMPipelineImpl(std::string& path, std::string device, const ov::AnyMap& config); + LLMPipelineImpl(std::string& path, std::string device, const ov::AnyMap& config, const std::string& ov_tokenizer_path=""); GenerationConfig generation_config() const; @@ -68,7 +70,8 @@ ov::LLMPipeline::LLMPipeline( const std::string model_path, const ov::Tokenizer& tokenizer, const std::string device, - const ov::AnyMap& plugin_config + const ov::AnyMap& plugin_config, + const std::string& ov_tokenizer_path ) { m_pimpl = make_unique<LLMPipelineImpl>(model_path, tokenizer, device, plugin_config); } @@ -77,7 +80,8 @@ ov::LLMPipeline::LLMPipelineImpl::LLMPipelineImpl( const std::string model_path, const ov::Tokenizer& tokenizer, std::string device, - const ov::AnyMap& plugin_config + const ov::AnyMap& plugin_config, + const std::string& ov_tokenizer_path ): m_tokenizer(tokenizer), m_device(device), m_plugin_config(plugin_config) { ov::Core core; @@ -91,25 +95,37 @@ ov::LLMPipeline::LLMPipelineImpl::LLMPipelineImpl( } } -ov::LLMPipeline::LLMPipeline(std::string& path, std::string device, const ov::AnyMap& config) { - m_pimpl = make_unique<LLMPipelineImpl>(path, device, config); +ov::LLMPipeline::LLMPipeline(std::string& path, std::string device, const ov::AnyMap& config, const std::string& ov_tokenizer_path) { + m_pimpl = make_unique<LLMPipelineImpl>(path, device, config, ov_tokenizer_path); } -ov::LLMPipeline::LLMPipelineImpl::LLMPipelineImpl(std::string& path, std::string device, const ov::AnyMap& config) { - std::string tokenizer_config_fname = "tokenizer_config.json"; - std::string generation_config_fname = "generation_config.json"; +ov::LLMPipeline::LLMPipelineImpl::LLMPipelineImpl(std::string& path, std::string device, + const ov::AnyMap& config, const std::string& ov_tokenizer_path) { + std::string config_path = path + "/" + "config.json"; + std::string tokenizer_config_path = path + "/" +"tokenizer_config.json"; + std::string generation_config_path = path + "/" +"generation_config.json"; + + if (std::filesystem::exists(generation_config_path)) { + m_generation_config = GenerationConfig(generation_config_path); + } else if (std::filesystem::exists(config_path)) { + // some models (e.g. google/gemma-*) do not have generation_config.json, but have config.json + // and special tokens are stored there. + + std::ifstream f(config_path); + OPENVINO_ASSERT(f.is_open(), "Failed to open '" + config_path + "' with config.json"); - if (std::filesystem::exists(path + "/" + generation_config_fname)) { - m_generation_config = GenerationConfig(path + "/" + generation_config_fname); - } - if (std::filesystem::exists(path + "/" + tokenizer_config_fname)) { - std::ifstream f(path + "/" + tokenizer_config_fname); nlohmann::json data = nlohmann::json::parse(f); - m_chat_template = data.value("chat_template", ""); + using ov::generate_utils::read_json_param; + read_json_param(data, "pad_token_id", m_generation_config.pad_token_id); + read_json_param(data, "bos_token_id", m_generation_config.bos_token_id); + read_json_param(data, "eos_token_id", m_generation_config.eos_token_id); + } + + if (std::filesystem::exists(tokenizer_config_path)) { + std::ifstream f(tokenizer_config_path); + ov::generate_utils::read_json_param(nlohmann::json::parse(f), "chat_template", m_chat_template); } - - m_device = device; ov::Core core; diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp index 75c18734d3..a7df42e2e0 100644 --- a/src/cpp/src/tokenizer.cpp +++ b/src/cpp/src/tokenizer.cpp @@ -53,15 +53,18 @@ class Tokenizer::TokenizerImpl { int64_t m_eos_token_id = 2; TokenizerImpl() = default; - TokenizerImpl(std::string tokenizers_path, const std::string device) { + TokenizerImpl(std::string tokenizers_path, const std::string device, const std::string& ov_tokenizer_path) { ov::Core core; if (ov::generate_utils::is_xml(tokenizers_path)) OPENVINO_THROW("tokenizers_path should be a path to a dir not a xml file"); - // todo:: OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt - core.add_extension(OPENVINO_TOKENIZERS_PATH); - + if (ov_tokenizer_path.empty()) { + // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt + core.add_extension(OPENVINO_TOKENIZERS_PATH); + } else { + core.add_extension(ov_tokenizer_path + "/libopenvino_tokenizers.so"); + } std::shared_ptr<ov::Model> tokenizer_model, detokenizer_model; try { tokenizer_model = core.read_model(tokenizers_path + "/openvino_tokenizer.xml"); @@ -141,8 +144,8 @@ class Tokenizer::TokenizerImpl { } }; -Tokenizer::Tokenizer(const std::string& tokenizers_path, const std::string& device) { - m_pimpl = std::make_shared<TokenizerImpl>(tokenizers_path, device); +Tokenizer::Tokenizer(const std::string& tokenizers_path, const std::string& device, const std::string& ov_tokenizer_path) { + m_pimpl = std::make_shared<TokenizerImpl>(tokenizers_path, device, ov_tokenizer_path); } std::pair<ov::Tensor, ov::Tensor> Tokenizer::encode(const std::string prompt) { diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 92df3d7067..dbd18cf3f3 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -18,11 +18,15 @@ void print_tensor(const ov::Tensor& tensor) { auto t_shape = tensor.get_shape(); std::cout << "["; - for (size_t i = 0; i < t_shape[1]; ++i) { - if (tensor.get_element_type() == ov::element::i64) { - res.emplace_back(tensor.data<int64_t>()[i]); - std::cout << tensor.data<int64_t>()[i] << " "; + for (size_t i = 0; i < t_shape[0]; ++i) { + std::cout << "|"; + for (size_t j = 0; j < t_shape[1]; ++j) { + if (tensor.get_element_type() == ov::element::i64) { + res.emplace_back(tensor.data<int64_t>()[t_shape[1] * i + j]); + std::cout << tensor.data<int64_t>()[t_shape[1] * i + j] << " "; + } } + std::cout << "|"; } std::cout << "]" << std::endl; } @@ -132,4 +136,4 @@ ov::Tensor extend_attention(ov::Tensor attention_mask) { } } // namespace generate_utils -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 7510c59e46..d7998a9594 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -4,6 +4,7 @@ #pragma once #include <openvino/openvino.hpp> +#include <nlohmann/json.hpp> namespace ov { namespace generate_utils { @@ -22,5 +23,41 @@ void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention bool is_xml(const std::string& path); +template <typename> +struct json_type_traits {}; + +template <> +struct json_type_traits<int> { static constexpr auto json_value_t = nlohmann::json::value_t::number_integer; }; + +template <> +struct json_type_traits<int64_t> { static constexpr auto json_value_t = nlohmann::json::value_t::number_integer; }; + +template <> +struct json_type_traits<size_t> { static constexpr auto json_value_t = nlohmann::json::value_t::number_unsigned; }; + +template <> +struct json_type_traits<float> { static constexpr auto json_value_t = nlohmann::json::value_t::number_float; }; + +template <> +struct json_type_traits<std::string> { static constexpr auto json_value_t = nlohmann::json::value_t::string; }; + +template <> +struct json_type_traits<bool> { static constexpr auto json_value_t = nlohmann::json::value_t::boolean; }; + +template <typename T> +void read_json_param(const nlohmann::json& data, const std::string& name, T& param) { + if (data.contains(name) && data[name].type() == json_type_traits<T>::json_value_t) { + param = data[name]; + } +} + +template <typename T> +void read_anymap_param(const ov::AnyMap& config_map, const std::string& name, T& param) { + if (config_map.count(name)) { + param = config_map.at(name).as<T>(); + } +} + } // namespace generate_utils -} // namespace ov \ No newline at end of file +} // namespace ov + diff --git a/src/python/openvino_genai/__init__.py b/src/python/openvino_genai/__init__.py index f604e03e84..e069157fa7 100644 --- a/src/python/openvino_genai/__init__.py +++ b/src/python/openvino_genai/__init__.py @@ -8,3 +8,7 @@ if hasattr(os, "add_dll_directory"): os.add_dll_directory(os.path.dirname(__file__)) + +from .py_generate_pipeline import LLMPipeline, Tokenizer, GenerationConfig, DecodedResults, EncodedResults + +__all__ = ['LLMPipeline', 'Tokenizer', 'GenerationConfig', 'DecodedResults', 'EncodedResults'] diff --git a/src/python/py_generate_pipeline.cpp b/src/python/py_generate_pipeline.cpp index 74cbe7e27d..efd286f4e0 100644 --- a/src/python/py_generate_pipeline.cpp +++ b/src/python/py_generate_pipeline.cpp @@ -62,15 +62,21 @@ std::string call_with_config(ov::LLMPipeline& pipe, const std::string& text, con return pipe(text, config); } +std::string genai_module_path() { + py::module_ m = py::module_::import("openvino_tokenizers"); + py::list path_list = m.attr("__path__"); + return std::string(py::str(path_list[0])) + "/lib"; +} + PYBIND11_MODULE(py_generate_pipeline, m) { m.doc() = "Pybind11 binding for LLM Pipeline"; - py::class_<LLMPipeline>(m, "LLMPipeline") - .def(py::init<const std::string, const ov::Tokenizer&, const std::string, const ov::AnyMap&>(), - py::arg("model_path"), py::arg("tokenizer"), py::arg("device") = "CPU", py::arg("plugin_config") = ov::AnyMap{}) - .def(py::init<std::string&, std::string, const ov::AnyMap&>(), - py::arg("path"), py::arg("device") = "CPU", py::arg("plugin_config") = ov::AnyMap{}) + .def(py::init<const std::string, const ov::Tokenizer&, const std::string, const ov::AnyMap&, const std::string&>(), + py::arg("model_path"), py::arg("tokenizer"), py::arg("device") = "CPU", + py::arg("plugin_config") = ov::AnyMap{}, py::arg("ov_tokenizer_path") = genai_module_path()) + .def(py::init<std::string&, std::string, const ov::AnyMap&, const std::string>(), + py::arg("path"), py::arg("device") = "CPU", py::arg("plugin_config") = ov::AnyMap{}, py::arg("ov_tokenizer_path") = genai_module_path()) .def("__call__", py::overload_cast<ov::LLMPipeline&, const std::string&, const py::kwargs&>(&call_with_kwargs)) .def("__call__", py::overload_cast<ov::LLMPipeline&, const std::string&, const ov::GenerationConfig&>(&call_with_config)) .def("generate", py::overload_cast<ov::LLMPipeline&, const std::string&, const py::kwargs&>(&call_with_kwargs)) @@ -96,7 +102,10 @@ PYBIND11_MODULE(py_generate_pipeline, m) { // Binding for Tokenizer py::class_<ov::Tokenizer>(m, "Tokenizer") .def(py::init<>()) - .def(py::init<std::string&, std::string>(), py::arg("tokenizers_path"), py::arg("device") = "CPU") + .def(py::init<std::string&, const std::string&, const std::string&>(), + py::arg("tokenizers_path"), + py::arg("device") = "CPU", + py::arg("ov_tokenizer_path") = py::str(genai_module_path())) // todo: implement encode/decode when for numpy inputs and outputs .def("encode", py::overload_cast<const std::string>(&ov::Tokenizer::encode), "Encode a single prompt") diff --git a/src/tests/python_tests/test_greedy.py b/src/tests/python_tests/test_greedy.py deleted file mode 100644 index f33909721b..0000000000 --- a/src/tests/python_tests/test_greedy.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (C) 2023-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -def test_tiny_llama(): - from transformers import AutoTokenizer, AutoModelForCausalLM - - tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") - model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") - - max_new_tokens = 32 - prompt = 'table is made of' - - encoded_prompt = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True) - hf_encoded_output = model.generate(encoded_prompt, max_new_tokens=max_new_tokens, do_sample=False) - hf_output = tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:]) - print(f'hf_output: {hf_output}') - - import sys - sys.path.append('src/python/openvino_genai/') - import py_generate_pipeline as genai - - pipe = genai.LLMPipeline('text_generation/causal_lm/TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/') - ov_output = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=False) - print(f'ov_output: {ov_output}') - - assert hf_output == ov_output - -if __name__ == '__main__': - test_tiny_llama() diff --git a/tests/python_tests/list_test_models.py b/tests/python_tests/list_test_models.py new file mode 100644 index 0000000000..f0786bf48c --- /dev/null +++ b/tests/python_tests/list_test_models.py @@ -0,0 +1,23 @@ +# generate_models.py + +def models_list(): + model_ids = [ + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0"), + ("google/gemma-2b-it", "gemma-2b-it"), + ("google/gemma-7b-it", "gemma-7b-it"), + # ("meta-llama/Llama-2-7b-chat-hf", "Llama-2-7b-chat-hf"), + # ("meta-llama/Llama-2-13b-chat-hf", "Llama-2-13b-chat-hf"), + # ("openlm-research/open_llama_3b", "open_llama_3b"), + # ("openlm-research/open_llama_7b", "open_llama_7b"), + # ("databricks/dolly-v2-3b", "dolly-v2-3b"), + # ("databricks/dolly-v2-12b", "dolly-v2-12b"), + # ("mistralai/Mistral-7B-v0.1", "Mistral-7B-v0.1"), + # ("ikala/redpajama-3b-chat", "redpajama-3b-chat"), + # ("microsoft/phi-1_5", "phi-1_5/"), + # ("Qwen/Qwen1.5-7B-Chat", "Qwen1.5-7B-Chat"), + ] + return model_ids + +if __name__ == "__main__": + for model_id, model_path in models_list(): + print(model_id, model_path) diff --git a/tests/python_tests/requirements.txt b/tests/python_tests/requirements.txt new file mode 100644 index 0000000000..776f43a254 --- /dev/null +++ b/tests/python_tests/requirements.txt @@ -0,0 +1,3 @@ +pytest +transformers +torch \ No newline at end of file diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_generate_api.py new file mode 100644 index 0000000000..9330e28d62 --- /dev/null +++ b/tests/python_tests/test_generate_api.py @@ -0,0 +1,110 @@ +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from list_test_models import models_list + + +@pytest.fixture(scope="module", params=models_list()) +def model_fixture(request): + model_id, path = request.param + from transformers import AutoTokenizer, AutoModelForCausalLM + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id) + return model_id, path, tokenizer, model + +def run_hf_ov_genai_comparison(model_fixture, generation_config, prompt): + model_id, path, tokenizer, model = model_fixture + + generation_config_hf = generation_config.copy() + # in OpenVINO GenAI this parameter is called stop_criteria, + # while in HF it's called early_stopping. + # HF values True, False and "never" correspond to OV GenAI values "early", "heuristic" and "never" + if generation_config_hf.get('stop_criteria'): + generation_config_hf['early_stopping'] = stop_criteria_map()[generation_config_hf.pop('stop_criteria')] + + encoded_prompt = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True) + hf_encoded_output = model.generate(encoded_prompt, **generation_config_hf) + hf_output = tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:]) + + import openvino_genai as ov_genai + pipe = ov_genai.LLMPipeline(path) + ov_output = pipe.generate(prompt, **generation_config) + + if hf_output != ov_output: + print(f'hf_output: {hf_output}') + print(f'ov_output: {ov_output}') + + assert hf_output == ov_output + + +def stop_criteria_map(): + return {"never": "never", "early": True, "heuristic": False} + +test_cases = [ + (dict(max_new_tokens=20, do_sample=False), 'table is made of'), # generation_config, prompt + (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'), + (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'Alan Turing was a'), + (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=30, diversity_penalty=1.0), 'Alan Turing was a'), + (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'), + (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'The Sun is yellow because'), + (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.5), 'The Sun is yellow because'), +] +@pytest.mark.parametrize("generation_config,prompt", test_cases) +def test_greedy_decoding(model_fixture, generation_config, prompt): + run_hf_ov_genai_comparison(model_fixture, generation_config, prompt) + + +prompts = ['The Sun is yellow because', 'Alan Turing was a', 'table is made of'] +@pytest.mark.parametrize("num_beam_groups", [2, 3, 8]) +@pytest.mark.parametrize("group_size", [5, 3, 10]) +@pytest.mark.parametrize("max_new_tokens", [20, 15]) +@pytest.mark.parametrize("diversity_penalty", [1.0, 1.5]) +@pytest.mark.parametrize("prompt", prompts) +def test_beam_search_decoding(model_fixture, num_beam_groups, group_size, + max_new_tokens, diversity_penalty, prompt): + generation_config = dict( + num_beam_groups=num_beam_groups, + num_beams=num_beam_groups * group_size, + diversity_penalty=diversity_penalty, + num_return_sequences=num_beam_groups * group_size, + max_new_tokens=max_new_tokens, + ) + run_hf_ov_genai_comparison(model_fixture, generation_config, prompt) + + +@pytest.mark.parametrize("stop_criteria", ["never", "early", "heuristic"]) +@pytest.mark.parametrize("prompt", prompts) +@pytest.mark.parametrize("max_new_tokens", [20, 40, 300]) +def test_stop_criteria(model_fixture, stop_criteria, prompt, max_new_tokens): + # todo: for long sentences early stop_criteria fails + if (stop_criteria == 'early' and max_new_tokens >= 300): + pytest.skip() + generation_config = dict( + num_beam_groups=2, + num_beams=2 * 3, + diversity_penalty=1.0, + num_return_sequences=2 * 3, + max_new_tokens=max_new_tokens, + stop_criteria=stop_criteria, + ) + run_hf_ov_genai_comparison(model_fixture, generation_config, prompt) + + +# test long sequences +@pytest.mark.parametrize("num_beam_groups", [2]) +@pytest.mark.parametrize("group_size", [5]) +@pytest.mark.parametrize("max_new_tokens", [800, 2000]) +@pytest.mark.parametrize("diversity_penalty", [1.0]) +@pytest.mark.parametrize("prompt", prompts) +@pytest.mark.skip # will be enabled in nightly since are computationally expensive +def test_beam_search_long_sentences(model_fixture, num_beam_groups, group_size, + max_new_tokens, diversity_penalty, prompt): + generation_config = dict( + num_beam_groups=num_beam_groups, + num_beams=num_beam_groups * group_size, + diversity_penalty=1.0, + num_return_sequences=num_beam_groups * group_size, + max_new_tokens=max_new_tokens, + ) + run_hf_ov_genai_comparison(model_fixture, generation_config, prompt) diff --git a/text_generation/causal_lm/cpp/CMakeLists.txt b/text_generation/causal_lm/cpp/CMakeLists.txt index 07d91e6d3b..1998c3ccb6 100644 --- a/text_generation/causal_lm/cpp/CMakeLists.txt +++ b/text_generation/causal_lm/cpp/CMakeLists.txt @@ -46,13 +46,7 @@ set_target_properties(prompt_lookup_decoding_lm PROPERTIES CXX_STANDARD_REQUIRED find_package(TBB REQUIRED COMPONENTS tbb) target_link_libraries(prompt_lookup_decoding_lm PRIVATE TBB::tbb) -add_executable(generate_sample generate_pipeline/generate_sample.cpp) -target_link_libraries(generate_sample PRIVATE openvino::genai) -target_include_directories(generate_sample PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}") -set_target_properties(generate_sample PROPERTIES CXX_STANDARD 17) -set_target_properties(generate_sample PROPERTIES CXX_STANDARD_REQUIRED ON) - -add_executable(chat_sample generate_pipeline/chat_sample.cpp) +add_executable(chat_sample chat_sample.cpp) target_link_libraries(chat_sample PRIVATE openvino::genai) target_include_directories(chat_sample PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}") set_target_properties(chat_sample PROPERTIES CXX_STANDARD 17) diff --git a/text_generation/causal_lm/cpp/generate_pipeline/chat_sample.cpp b/text_generation/causal_lm/cpp/chat_sample.cpp similarity index 100% rename from text_generation/causal_lm/cpp/generate_pipeline/chat_sample.cpp rename to text_generation/causal_lm/cpp/chat_sample.cpp diff --git a/text_generation/causal_lm/cpp/generate_pipeline/generate_sample.cpp b/text_generation/causal_lm/cpp/generate_pipeline/generate_sample.cpp deleted file mode 100644 index 84e07c394b..0000000000 --- a/text_generation/causal_lm/cpp/generate_pipeline/generate_sample.cpp +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include "openvino/genai/llm_pipeline.hpp" - -using std::cout; -using std::endl; - -int main(int argc, char* argv[]) { - if (2 > argc && argc > 4) - throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> \"<PROMPT>\" <DEVICE>"); - std::string model_path = argv[1]; - - std::string prompt = "table is made of "; - std::string device = "CPU"; // can be replaced with GPU - - if (argc > 2) - prompt = argv[2]; - if (argc > 3) - device = argv[3]; - - // Example 1: Simplest example with greedy search - // Model, tokenizer and generation_config.json will be loaded from the model_path. - // If generation_config.json is not found default velues for gready search will be used - - // ov::streamer_lambda([](std::string subword){std::cout << subword << std::flush;}) - ov::LLMPipeline pipe(model_path, device); - // cout << prompt << pipe(prompt, ov::max_new_tokens(1000)) << endl; - - // todo: syntactic sugar to specify generation configs in place - // cout << prompt << pipe(prompt, ov::max_new_tokens(100)) << endl; - - - auto tokenizer = ov::Tokenizer(model_path); - auto [input_ids, attention_mask] = tokenizer.encode("table is made of "); - auto resuling_tokens = pipe.generate(input_ids, ov::max_new_tokens(1000)); - cout << tokenizer.decode(resuling_tokens.tokens[0]) << endl; - - // Example 2: Modifying generation_cofnig to use grouped beam search - ov::GenerationConfig config = pipe.get_generation_config(); - config.max_new_tokens = 100; - config.num_beams = 15; - config.num_beam_groups = 3; - // cout << prompt << pipe(prompt, config) << endl; - - // cout << endl << "grouped beam search generated candidates:" << endl; - // for (int i = 0; i < num_return_sequences; ++i) - // will return vector with num_return_sequences strings - // auto num_return_sequences = 3; - - // // Example 3: Greedy Decoding with multiple batch - // pipe = ov::LLMPipeline(model_path, device); - // config = pipe.generation_config(); - - // cout << endl << "greedy decoding with multiple batches:" << endl; - // std::vector<std::string> prompts = {"table is made of", "Alan Turing was a", "1 + 1 = ", "Why is the Sun yellow?"}; - // auto results = pipe(prompts, config.max_new_tokens(20)); - // for (const auto& res: results) - // std::cout << res.text << std::endl; - - // // Example 4: Calling tokenizer/detokenizer manually and getting beam scores for all candidates - // pipe = ov::LLMPipeline(model_path); - // auto [input_ids, attention_mask] = pipe.get_tokenizer().tokenize({prompt}); - // config = GenerationConfig::beam_search(); - // // config for grouped beam search - // config.max_new_tokens(30).num_groups(3).group_size(5).num_return_sequences(15); - - // cout << endl << "beam search with printing of all candidates:" << endl; - // auto beams = pipe.generate(input_ids, attention_mask, config); - // for (size_t i = 0; i < beams.scores.size(); i++) { - // std::cout << beams.scores[i] << ": " << pipe.get_tokenizer().detokenize(beams.tokens[i]) << std::endl; - // } - - // // for (const auto& beam : beams.second) - // // std::cout << beam.first << ": " << pipe.detokenize(beam.second) << std::endl; - - // { - // // Example 5: Speculative sampling - // std::string assitive_model_path = "text_generation/causal_lm/TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16"; - // pipe = ov::LLMPipeline(model_path); - // auto [input_ids, attention_mask] = pipe.get_tokenizer().tokenize({prompt}); - // // config = GenerationConfig::assistive_decoding(assitive_model_path).num_assistant_tokens(5).max_new_tokens(20); - // pipe.generation_config().assistant_model(assitive_model_path); - - // cout << endl << "Speculative sampling with TinyLlama assistance:" << endl; - // auto results = pipe.generate(input_ids, attention_mask, config); - // for (size_t i = 0; i < beams.scores.size(); i++) { - // for (const auto& result : results) - // std::cout << pipe.get_tokenizer().detokenize(result.tokens) << std::endl; - // } - // } - - return 0; -} diff --git a/text_generation/causal_lm/cpp/greedy_causal_lm.cpp b/text_generation/causal_lm/cpp/greedy_causal_lm.cpp index 7b1dde4dc8..e410d170ca 100644 --- a/text_generation/causal_lm/cpp/greedy_causal_lm.cpp +++ b/text_generation/causal_lm/cpp/greedy_causal_lm.cpp @@ -17,6 +17,7 @@ int main(int argc, char* argv[]) try { ov::LLMPipeline pipe(model_path, device); ov::GenerationConfig config = pipe.get_generation_config(); config.max_new_tokens = 100; + config.do_sample = false; auto streamer = [](std::string subword){std::cout << subword << std::flush;}; // since streamer is set results will be printed each time a new token is generated From ce4eb00ff94d53f5e1840bd2e3320356c9a86616 Mon Sep 17 00:00:00 2001 From: Pavel Esir <pavel.esir@gmail.com> Date: Wed, 22 May 2024 11:35:55 +0200 Subject: [PATCH 02/40] Apply suggestions from code review Co-authored-by: Zlobin Vladimir <vladimir.zlobin@intel.com> Co-authored-by: Alexander Suvorov <alexander.suvorov@intel.com> --- CMakeLists.txt | 10 ++++++++-- .../openvino/genai/generation_config.hpp | 17 +++++++++-------- src/cpp/include/openvino/genai/llm_pipeline.hpp | 6 +++--- src/cpp/src/generation_config.cpp | 1 - src/cpp/src/group_beam_searcher.hpp | 2 +- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 809327095c..ac392233a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,8 +4,14 @@ cmake_minimum_required(VERSION 3.15) -set(CMAKE_BUILD_TYPE "Release" CACHE STRING "CMake build type") -set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Release" "Debug" "RelWithDebInfo" "MinSizeRel") +# Multi config generators such as Visual Studio ignore CMAKE_BUILD_TYPE. Multi config generators are configured with +# CMAKE_CONFIGURATION_TYPES, but limiting options in it completely removes such build options +get_property(GENERATOR_IS_MULTI_CONFIG_VAR GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) +if(NOT GENERATOR_IS_MULTI_CONFIG_VAR AND NOT DEFINED CMAKE_BUILD_TYPE) + message(STATUS "CMAKE_BUILD_TYPE is not defined, 'Release' will be used") + # Setting CMAKE_BUILD_TYPE as CACHE must go before project(). Otherwise project() sets its value and set() doesn't take an effect + set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel ...") +endif() project(openvino_genai VERSION 2024.2.0.0) diff --git a/src/cpp/include/openvino/genai/generation_config.hpp b/src/cpp/include/openvino/genai/generation_config.hpp index e1f2151d49..837fae21ad 100644 --- a/src/cpp/include/openvino/genai/generation_config.hpp +++ b/src/cpp/include/openvino/genai/generation_config.hpp @@ -14,9 +14,10 @@ namespace ov { /** - * @brief controls the stopping condition for grouped beam search. The following values are possible: - * "early", where the generation stops as soon as there are `num_beams` complete candidates; "heuristic", where an - * heuristic is applied and the generation stops when is it very unlikely to find better candidates; + * @brief controls the stopping condition for grouped beam search. The following values are possible: + * "early" stops as soon as there are `num_beams` complete candidates. + "heuristic" stops when is it unlikely to find better candidates. + "never" stops when there cannot be better candidates. */ enum class StopCriteria { early, heuristic, never }; @@ -25,11 +26,11 @@ enum class StopCriteria { early, heuristic, never }; * * @param max_length the maximum length the generated tokens can have. Corresponds to the length of the input prompt + * `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. - * @param max_new_tokens the maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. + * @param max_new_tokens the maximum numbers of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length. * @param ignore_eos if set to true, then generation will not stop even if <eos> token is met. - * @param num_beams number of beams for beam search. 1 means no beam search. + * @param num_beams number of beams for beam search. 1 disables beam search. * @param num_beam_groups number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. - * @param diversity_penalty this value is subtracted from a beam's score if it generates a token same as any beam from other group at a + * @param diversity_penalty this value is subtracted from a beam's score if it generates the same token as any beam from other group at a * particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. * @param length_penalty exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to * the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log @@ -42,11 +43,11 @@ enum class StopCriteria { early, heuristic, never }; * heuristic is applied and the generation stops when is it very unlikely to find better candidates; * "never", where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm). * @param temperature the value used to modulate token probabilities for random sampling - * @param top_p if set to float < 1, only the smallest set of most probable tokens with probabilities + * @param top_p - if set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. * @param top_k the number of highest probability vocabulary tokens to keep for top-k-filtering. * @param do_sample whether or not to use multinomial random sampling * that add up to `top_p` or higher are kept. - * @param repetition_penalty the parameter for repetition penalty. 1.0 means no penalty. + * @param repetition_penalty the parameter for repetition penalty. 1.0 means no penalty. See https://arxiv.org/pdf/1909.05858. * @param pad_token_id id of padding token * @param bos_token_id id of <bos> token * @param eos_token_id id of <eos> token diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index 2a6e53eea6..1345b488f4 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -65,7 +65,7 @@ class DecodedResults { class OPENVINO_GENAI_EXPORTS LLMPipeline { public: /** - * @brief Constructs a LLMPipeline when convert model xml/bin files, tokenizers and configuration and in the same dir. + * @brief Constructs an LLMPipeline from xml/bin files, tokenizers and configuration in the same dir. * * @param model_path Path to the dir model xml/bin files, tokenizers and generation_configs.json * @param device optional device @@ -105,8 +105,8 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline { template <typename... Properties> util::EnableIfAllStringAny<std::string, Properties...> generate( - std::string text, - Properties&&... properties) { + std::string text, + Properties&&... properties) { return generate(text, AnyMap{std::forward<Properties>(properties)...}); } std::string generate(std::string text, const ov::AnyMap& config); diff --git a/src/cpp/src/generation_config.cpp b/src/cpp/src/generation_config.cpp index f7cbcaa075..14fc370c59 100644 --- a/src/cpp/src/generation_config.cpp +++ b/src/cpp/src/generation_config.cpp @@ -37,7 +37,6 @@ GenerationConfig::GenerationConfig(std::string json_path) { read_json_param(data, "length_penalty", length_penalty); read_json_param(data, "num_return_sequences", num_return_sequences); read_json_param(data, "no_repeat_ngram_size", no_repeat_ngram_size); - // stop_criteria will be processed below read_json_param(data, "temperature", temperature); read_json_param(data, "top_p", top_p); read_json_param(data, "top_k", top_k); diff --git a/src/cpp/src/group_beam_searcher.hpp b/src/cpp/src/group_beam_searcher.hpp index 91f3ef4096..5362c9cfae 100644 --- a/src/cpp/src/group_beam_searcher.hpp +++ b/src/cpp/src/group_beam_searcher.hpp @@ -8,5 +8,5 @@ #include "openvino/genai/llm_pipeline.hpp" namespace ov { - EncodedResults beam_search(ov::InferRequest& lm, ov::Tensor prompts, ov::Tensor attentin_mask, GenerationConfig sampling_params); + EncodedResults beam_search(ov::InferRequest& lm, ov::Tensor prompts, ov::Tensor attentin_mask, GenerationConfig config); } From aa90e9d229cc2357acee7e1d202c1a7d5871a63b Mon Sep 17 00:00:00 2001 From: Pavel Esir <pavel.esir@intel.com> Date: Wed, 22 May 2024 12:09:22 +0200 Subject: [PATCH 03/40] names correction --- .github/workflows/genai_python_lib.yml | 10 ++++++---- src/cpp/include/openvino/genai/llm_pipeline.hpp | 4 ++-- src/cpp/include/openvino/genai/tokenizer.hpp | 2 +- src/cpp/src/llm_pipeline.cpp | 14 +++++++------- src/cpp/src/tokenizer.cpp | 10 +++++----- src/python/py_generate_pipeline.cpp | 10 +++++----- 6 files changed, 26 insertions(+), 24 deletions(-) diff --git a/.github/workflows/genai_python_lib.yml b/.github/workflows/genai_python_lib.yml index 6697ba934a..58e78d3b36 100644 --- a/.github/workflows/genai_python_lib.yml +++ b/.github/workflows/genai_python_lib.yml @@ -16,11 +16,13 @@ jobs: - run: source ./ov/setupvars.sh && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: source ./ov/setupvars.sh && cmake --build ./build/ --config Release -j - run: python -m pip install --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly # Can't load CentOS libraries from the archive - - run: PYTHONPATH=./src/python/ python -c "from openvino_genai.py_generate_pipeline import LLMPipeline" + - run: PYTHONPATH=./src/python/ python -c "from openvino_genai import LLMPipeline" - run: source ./ov/setupvars.sh && python -m pip install --pre . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly + - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: python -c "from openvino_genai import LLMPipeline" - - name: Install optimum-cli and run for each model + - name: GenAI Python API tests run: | + source ./ov/setupvars.sh cd ./tests/ python -m pip install -r requirements.txt models=$(python3 generate_models.py) @@ -47,6 +49,6 @@ jobs: - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --build ./build/ --config Release -j - run: python -m pip install "numpy<1.27" - - run: set "PYTHONPATH=./src/python;" && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -c "from openvino_genai.py_generate_pipeline import LLMPipeline" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. + - run: set "PYTHONPATH=./src/python;" && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -c "from openvino_genai import LLMPipeline" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install . - - run: python -c "from openvino_genai.py_generate_pipeline import LLMPipeline" + - run: python -c "from openvino_genai import LLMPipeline" diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index 1345b488f4..3bc8453d4e 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -73,7 +73,7 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline { */ LLMPipeline(std::string& path, std::string device="CPU", const ov::AnyMap& plugin_config={}, - const std::string& ov_tokenizer_path=""); + const std::string& ov_tokenizers_path=""); /** * @brief Constructs a LLMPipeline when ov::Tokenizer is initialized manually using file from the different dirs. @@ -88,7 +88,7 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline { const ov::Tokenizer& tokenizer, const std::string device="CPU", const ov::AnyMap& plugin_config = {}, - const std::string& ov_tokenizer_path="" + const std::string& ov_tokenizers_path="" ); ~LLMPipeline(); diff --git a/src/cpp/include/openvino/genai/tokenizer.hpp b/src/cpp/include/openvino/genai/tokenizer.hpp index 07bfe96d44..03c0cd64f7 100644 --- a/src/cpp/include/openvino/genai/tokenizer.hpp +++ b/src/cpp/include/openvino/genai/tokenizer.hpp @@ -21,7 +21,7 @@ class OPENVINO_GENAI_EXPORTS Tokenizer { * @param tokenizer_path openvino_tokenizer.xml and openvino_detokenizer.xml should be located in the tokenizer_path * @param device device. Currently only 'CPU' is supported */ - Tokenizer(const std::string& tokenizers_path, const std::string& device="CPU", const std::string& ov_tokenizer_path=""); + Tokenizer(const std::string& tokenizers_path, const std::string& device="CPU", const std::string& ov_tokenizers_path=""); /** * @brief encode a single prompt diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index c47e54dd48..7485998ab0 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -45,11 +45,11 @@ class LLMPipeline::LLMPipelineImpl { const ov::Tokenizer& tokenizer, const std::string device, const ov::AnyMap& plugin_config, - const std::string& ov_tokenizer_path="" + const std::string& ov_tokenizers_path="" ); LLMPipelineImpl(std::string& path, std::string device, const ov::AnyMap& config); - LLMPipelineImpl(std::string& path, std::string device, const ov::AnyMap& config, const std::string& ov_tokenizer_path=""); + LLMPipelineImpl(std::string& path, std::string device, const ov::AnyMap& config, const std::string& ov_tokenizers_path=""); GenerationConfig generation_config() const; @@ -71,7 +71,7 @@ ov::LLMPipeline::LLMPipeline( const ov::Tokenizer& tokenizer, const std::string device, const ov::AnyMap& plugin_config, - const std::string& ov_tokenizer_path + const std::string& ov_tokenizers_path ) { m_pimpl = make_unique<LLMPipelineImpl>(model_path, tokenizer, device, plugin_config); } @@ -81,7 +81,7 @@ ov::LLMPipeline::LLMPipelineImpl::LLMPipelineImpl( const ov::Tokenizer& tokenizer, std::string device, const ov::AnyMap& plugin_config, - const std::string& ov_tokenizer_path + const std::string& ov_tokenizers_path ): m_tokenizer(tokenizer), m_device(device), m_plugin_config(plugin_config) { ov::Core core; @@ -95,12 +95,12 @@ ov::LLMPipeline::LLMPipelineImpl::LLMPipelineImpl( } } -ov::LLMPipeline::LLMPipeline(std::string& path, std::string device, const ov::AnyMap& config, const std::string& ov_tokenizer_path) { - m_pimpl = make_unique<LLMPipelineImpl>(path, device, config, ov_tokenizer_path); +ov::LLMPipeline::LLMPipeline(std::string& path, std::string device, const ov::AnyMap& config, const std::string& ov_tokenizers_path) { + m_pimpl = make_unique<LLMPipelineImpl>(path, device, config, ov_tokenizers_path); } ov::LLMPipeline::LLMPipelineImpl::LLMPipelineImpl(std::string& path, std::string device, - const ov::AnyMap& config, const std::string& ov_tokenizer_path) { + const ov::AnyMap& config, const std::string& ov_tokenizers_path) { std::string config_path = path + "/" + "config.json"; std::string tokenizer_config_path = path + "/" +"tokenizer_config.json"; std::string generation_config_path = path + "/" +"generation_config.json"; diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp index a7df42e2e0..778778faec 100644 --- a/src/cpp/src/tokenizer.cpp +++ b/src/cpp/src/tokenizer.cpp @@ -53,17 +53,17 @@ class Tokenizer::TokenizerImpl { int64_t m_eos_token_id = 2; TokenizerImpl() = default; - TokenizerImpl(std::string tokenizers_path, const std::string device, const std::string& ov_tokenizer_path) { + TokenizerImpl(std::string tokenizers_path, const std::string device, const std::string& ov_tokenizers_path) { ov::Core core; if (ov::generate_utils::is_xml(tokenizers_path)) OPENVINO_THROW("tokenizers_path should be a path to a dir not a xml file"); - if (ov_tokenizer_path.empty()) { + if (ov_tokenizers_path.empty()) { // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt core.add_extension(OPENVINO_TOKENIZERS_PATH); } else { - core.add_extension(ov_tokenizer_path + "/libopenvino_tokenizers.so"); + core.add_extension(ov_tokenizers_path + "/libopenvino_tokenizers.so"); } std::shared_ptr<ov::Model> tokenizer_model, detokenizer_model; try { @@ -144,8 +144,8 @@ class Tokenizer::TokenizerImpl { } }; -Tokenizer::Tokenizer(const std::string& tokenizers_path, const std::string& device, const std::string& ov_tokenizer_path) { - m_pimpl = std::make_shared<TokenizerImpl>(tokenizers_path, device, ov_tokenizer_path); +Tokenizer::Tokenizer(const std::string& tokenizers_path, const std::string& device, const std::string& ov_tokenizers_path) { + m_pimpl = std::make_shared<TokenizerImpl>(tokenizers_path, device, ov_tokenizers_path); } std::pair<ov::Tensor, ov::Tensor> Tokenizer::encode(const std::string prompt) { diff --git a/src/python/py_generate_pipeline.cpp b/src/python/py_generate_pipeline.cpp index efd286f4e0..2aee67593c 100644 --- a/src/python/py_generate_pipeline.cpp +++ b/src/python/py_generate_pipeline.cpp @@ -62,9 +62,9 @@ std::string call_with_config(ov::LLMPipeline& pipe, const std::string& text, con return pipe(text, config); } -std::string genai_module_path() { +std::string ov_tokenizers_module_path() { py::module_ m = py::module_::import("openvino_tokenizers"); - py::list path_list = m.attr("__path__"); + py::list path_list = m.attr("__path__"); return std::string(py::str(path_list[0])) + "/lib"; } @@ -74,9 +74,9 @@ PYBIND11_MODULE(py_generate_pipeline, m) { py::class_<LLMPipeline>(m, "LLMPipeline") .def(py::init<const std::string, const ov::Tokenizer&, const std::string, const ov::AnyMap&, const std::string&>(), py::arg("model_path"), py::arg("tokenizer"), py::arg("device") = "CPU", - py::arg("plugin_config") = ov::AnyMap{}, py::arg("ov_tokenizer_path") = genai_module_path()) + py::arg("plugin_config") = ov::AnyMap{}, py::arg("ov_tokenizers_path") = ov_tokenizers_module_path()) .def(py::init<std::string&, std::string, const ov::AnyMap&, const std::string>(), - py::arg("path"), py::arg("device") = "CPU", py::arg("plugin_config") = ov::AnyMap{}, py::arg("ov_tokenizer_path") = genai_module_path()) + py::arg("path"), py::arg("device") = "CPU", py::arg("plugin_config") = ov::AnyMap{}, py::arg("ov_tokenizers_path") = ov_tokenizers_module_path()) .def("__call__", py::overload_cast<ov::LLMPipeline&, const std::string&, const py::kwargs&>(&call_with_kwargs)) .def("__call__", py::overload_cast<ov::LLMPipeline&, const std::string&, const ov::GenerationConfig&>(&call_with_config)) .def("generate", py::overload_cast<ov::LLMPipeline&, const std::string&, const py::kwargs&>(&call_with_kwargs)) @@ -105,7 +105,7 @@ PYBIND11_MODULE(py_generate_pipeline, m) { .def(py::init<std::string&, const std::string&, const std::string&>(), py::arg("tokenizers_path"), py::arg("device") = "CPU", - py::arg("ov_tokenizer_path") = py::str(genai_module_path())) + py::arg("ov_tokenizers_path") = py::str(ov_tokenizers_module_path())) // todo: implement encode/decode when for numpy inputs and outputs .def("encode", py::overload_cast<const std::string>(&ov::Tokenizer::encode), "Encode a single prompt") From d843229835d5bc29d33ed730ee5d2a863249cb81 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 14:56:18 +0400 Subject: [PATCH 04/40] enable --- .github/workflows/genai_package.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index b6f1647c7a..3f3ee5082e 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -2,7 +2,6 @@ name: genai_package on: pull_request jobs: ubuntu_genai_package: - if: false runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v4 From 2c1d1ef7996fe6e1844fb2acdeec6921121577ee Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 15:54:38 +0400 Subject: [PATCH 05/40] libtbb-dev --- .github/workflows/genai_package.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index 3f3ee5082e..ba6e223b92 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -16,6 +16,7 @@ jobs: - run: source ./ov/setupvars.sh && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: source ./ov/setupvars.sh && cmake --build ./build/ --config Release --target package -j - run: source ./ov/setupvars.sh && cmake --install ./build/ --config Release --prefix ov + - run: sudo apt-get install libtbb-dev - run: ov/samples/cpp/build_samples.sh -b "${{ github.workspace }}/s pace" - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] From 57ca2d48f682b4671fdc10d0604a82117b41cd19 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 15:57:21 +0400 Subject: [PATCH 06/40] move --- .github/workflows/genai_package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index ba6e223b92..6b340c8d2f 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -13,10 +13,10 @@ jobs: - run: mkdir ./ov/ - run: curl https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.1/linux/l_openvino_toolkit_ubuntu20_2024.1.0.15008.f4afc983258_x86_64.tgz | tar --directory ./ov/ --strip-components 1 -xz - run: sudo ./ov/install_dependencies/install_openvino_dependencies.sh + - run: sudo apt-get install libtbb-dev - run: source ./ov/setupvars.sh && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: source ./ov/setupvars.sh && cmake --build ./build/ --config Release --target package -j - run: source ./ov/setupvars.sh && cmake --install ./build/ --config Release --prefix ov - - run: sudo apt-get install libtbb-dev - run: ov/samples/cpp/build_samples.sh -b "${{ github.workspace }}/s pace" - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] From 37844c90e63f2ede84c8a39f3293ef17659eb0be Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 17:38:57 +0400 Subject: [PATCH 07/40] slash --- .github/workflows/genai_package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index 6b340c8d2f..65d3e6da3e 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -21,7 +21,7 @@ jobs: - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: source ./ov/setupvars.sh && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 - - run: source ./ov/setupvars.sh && timeout 50s "${{ github.workspace }}/s pace/intel64/Release/greedy_causal_lm" .\TinyLlama-1.1B-Chat-v1.0\ "" + - run: source ./ov/setupvars.sh && timeout 50s "${{ github.workspace }}/s pace/intel64/Release/greedy_causal_lm" ./TinyLlama-1.1B-Chat-v1.0/ "" windows_genai_package: runs-on: windows-latest From 5cff21e13bb17a4fdccd0641c79bd29fbce6096d Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 18:24:33 +0400 Subject: [PATCH 08/40] install --- src/cpp/CMakeLists.txt | 13 +++++++++---- src/python/CMakeLists.txt | 1 + 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index ffe28a81df..675edccd44 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -66,7 +66,12 @@ add_custom_command(TARGET ${TARGET_NAME} POST_BUILD "${CMAKE_CURRENT_SOURCE_DIR}/../python/openvino_genai/$<TARGET_FILE_NAME:${TARGET_NAME}>" COMMENT "Copy ${TARGET_NAME} to src/python/openvino_genai") -install(TARGETS ${TARGET_NAME} LIBRARY DESTINATION . COMPONENT core_genai RUNTIME DESTINATION . COMPONENT core_genai) +find_package(Python3 REQUIRED COMPONENTS Interpreter Development) +install(TARGETS ${TARGET_NAME} + LIBRARY DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR} + RUNTIME DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR}) + +install(TARGETS ${TARGET_NAME} LIBRARY DESTINATION . COMPONENT pygenai RUNTIME DESTINATION . COMPONENT pygenai) # - Windows: `<openvino_dir>\runtime\bin\intel64\Release\` # - MacOS_x86: `<openvino_dir>/runtime/lib/intel64/Release` @@ -90,9 +95,10 @@ if(MSVC OR APPLE) set(ARCH_DIR ${ARCH_DIR}/${CMAKE_BUILD_TYPE}) endif() install(TARGETS ${TARGET_NAME} EXPORT openvino_genaiTargets - LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai_dev + LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai + RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai) +install(TARGETS ${TARGET_NAME} EXPORT openvino_genaiTargets ARCHIVE DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai_dev - RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai_dev INCLUDES DESTINATION runtime/include) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION runtime/include COMPONENT core_genai_dev) install(EXPORT openvino_genaiTargets FILE openvino_genaiTargets.cmake NAMESPACE openvino:: DESTINATION runtime/cmake) @@ -102,4 +108,3 @@ install(FILES "${CMAKE_BINARY_DIR}/openvino_genaiConfig.cmake" "${CMAKE_BINARY_D include(CMakePackageConfigHelpers) write_basic_package_version_file("${CMAKE_BINARY_DIR}/openvino_genaiConfigVersion.cmake" VERSION ${CMAKE_PROJECT_VERSION} COMPATIBILITY AnyNewerVersion) export(EXPORT openvino_genaiTargets FILE "${CMAKE_BINARY_DIR}/openvino_genaiTargets.cmake" NAMESPACE openvino::) -# export(TARGETS ${TARGET_NAME} NAMESPACE openvino:: FILE "${CMAKE_BINARY_DIR}/openvino_genaiConfig.cmake") TODO diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 62f26f3215..0ebdc78cd9 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -46,3 +46,4 @@ add_custom_command(TARGET py_generate_pipeline POST_BUILD find_package(Python3 REQUIRED COMPONENTS Interpreter Development) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/openvino_genai/ DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR}) +install(TARGETS py_generate_pipeline LIBRARY DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR}) From 561b55a15794b9ee7c51397dcd3f987bb33968fd Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 18:29:58 +0400 Subject: [PATCH 09/40] core_genai_dev --- src/cpp/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index 675edccd44..c6b7fd7a1a 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -98,8 +98,10 @@ install(TARGETS ${TARGET_NAME} EXPORT openvino_genaiTargets LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai) install(TARGETS ${TARGET_NAME} EXPORT openvino_genaiTargets + LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai_dev ARCHIVE DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai_dev - INCLUDES DESTINATION runtime/include) + RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai_dev + INCLUDES DESTINATION runtime/include COMPONENT core_genai_dev) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION runtime/include COMPONENT core_genai_dev) install(EXPORT openvino_genaiTargets FILE openvino_genaiTargets.cmake NAMESPACE openvino:: DESTINATION runtime/cmake) include(CMakePackageConfigHelpers) From 260d913940235ba567eeafc28c551a91d3c6cc06 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 18:31:55 +0400 Subject: [PATCH 10/40] remove export --- src/cpp/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index c6b7fd7a1a..e43e2e9c1d 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -94,7 +94,7 @@ endif() if(MSVC OR APPLE) set(ARCH_DIR ${ARCH_DIR}/${CMAKE_BUILD_TYPE}) endif() -install(TARGETS ${TARGET_NAME} EXPORT openvino_genaiTargets +install(TARGETS ${TARGET_NAME} LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai) install(TARGETS ${TARGET_NAME} EXPORT openvino_genaiTargets From 5a0079bba7dc00095e0940f0183d48ed3be7cc57 Mon Sep 17 00:00:00 2001 From: Pavel Esir <pavel.esir@intel.com> Date: Wed, 22 May 2024 13:38:26 +0200 Subject: [PATCH 11/40] install openvino_tokenizers for genai_python_lib --- .github/workflows/genai_python_lib.yml | 12 ++++++------ src/cpp/src/llm_pipeline.cpp | 5 ++--- tests/python_tests/list_test_models.py | 6 ++---- tests/python_tests/requirements.txt | 3 ++- tests/python_tests/test_generate_api.py | 19 ++++++++++++------- 5 files changed, 24 insertions(+), 21 deletions(-) diff --git a/.github/workflows/genai_python_lib.yml b/.github/workflows/genai_python_lib.yml index 58e78d3b36..29f537858a 100644 --- a/.github/workflows/genai_python_lib.yml +++ b/.github/workflows/genai_python_lib.yml @@ -2,7 +2,7 @@ name: genai_python_lib on: pull_request jobs: ubuntu_genai_python_lib: - runs-on: ubuntu-20.04 + runs-on: ubuntu-20.04-16-cores steps: - uses: actions/checkout@v4 with: @@ -16,18 +16,17 @@ jobs: - run: source ./ov/setupvars.sh && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: source ./ov/setupvars.sh && cmake --build ./build/ --config Release -j - run: python -m pip install --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly # Can't load CentOS libraries from the archive + - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: PYTHONPATH=./src/python/ python -c "from openvino_genai import LLMPipeline" - run: source ./ov/setupvars.sh && python -m pip install --pre . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly - - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: python -c "from openvino_genai import LLMPipeline" - name: GenAI Python API tests run: | source ./ov/setupvars.sh - cd ./tests/ + cd ./tests/python_tests/ python -m pip install -r requirements.txt - models=$(python3 generate_models.py) + models=$(python list_test_models.py) echo "$models" | while read -r model_name model_path; do - echo "Processing model: $model_name at $model_path" optimum-cli export openvino --trust-remote-code --weight-format fp16 --model "$model_name" "$model_path" done python -m pytest test_generate_api.py @@ -49,6 +48,7 @@ jobs: - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --build ./build/ --config Release -j - run: python -m pip install "numpy<1.27" + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: set "PYTHONPATH=./src/python;" && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -c "from openvino_genai import LLMPipeline" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install . + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install . - run: python -c "from openvino_genai import LLMPipeline" diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 7485998ab0..4415e507fe 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -48,7 +48,6 @@ class LLMPipeline::LLMPipelineImpl { const std::string& ov_tokenizers_path="" ); - LLMPipelineImpl(std::string& path, std::string device, const ov::AnyMap& config); LLMPipelineImpl(std::string& path, std::string device, const ov::AnyMap& config, const std::string& ov_tokenizers_path=""); GenerationConfig generation_config() const; @@ -73,7 +72,7 @@ ov::LLMPipeline::LLMPipeline( const ov::AnyMap& plugin_config, const std::string& ov_tokenizers_path ) { - m_pimpl = make_unique<LLMPipelineImpl>(model_path, tokenizer, device, plugin_config); + m_pimpl = make_unique<LLMPipelineImpl>(model_path, tokenizer, device, plugin_config, ov_tokenizers_path); } ov::LLMPipeline::LLMPipelineImpl::LLMPipelineImpl( @@ -130,7 +129,7 @@ ov::LLMPipeline::LLMPipelineImpl::LLMPipelineImpl(std::string& path, std::string ov::Core core; m_model_runner = core.compile_model(path + "/openvino_model.xml", device, config).create_infer_request(); - m_tokenizer = Tokenizer(path); + m_tokenizer = Tokenizer(path, device, ov_tokenizers_path); } ov::GenerationConfig ov::LLMPipeline::LLMPipelineImpl::generation_config() const { diff --git a/tests/python_tests/list_test_models.py b/tests/python_tests/list_test_models.py index f0786bf48c..09addcfaba 100644 --- a/tests/python_tests/list_test_models.py +++ b/tests/python_tests/list_test_models.py @@ -1,10 +1,8 @@ -# generate_models.py - def models_list(): model_ids = [ ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0"), - ("google/gemma-2b-it", "gemma-2b-it"), - ("google/gemma-7b-it", "gemma-7b-it"), + # ("google/gemma-2b-it", "gemma-2b-it"), + # ("google/gemma-7b-it", "gemma-7b-it"), # ("meta-llama/Llama-2-7b-chat-hf", "Llama-2-7b-chat-hf"), # ("meta-llama/Llama-2-13b-chat-hf", "Llama-2-13b-chat-hf"), # ("openlm-research/open_llama_3b", "open_llama_3b"), diff --git a/tests/python_tests/requirements.txt b/tests/python_tests/requirements.txt index 776f43a254..e536fd531e 100644 --- a/tests/python_tests/requirements.txt +++ b/tests/python_tests/requirements.txt @@ -1,3 +1,4 @@ pytest transformers -torch \ No newline at end of file +torch +optimum-intel[openvino] @ git+https://github.com/huggingface/optimum-intel.git@fb1b35bef23242d65b2fb057c4a7ac78a7cfd4c3 \ No newline at end of file diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_generate_api.py index 9330e28d62..d0bec03107 100644 --- a/tests/python_tests/test_generate_api.py +++ b/tests/python_tests/test_generate_api.py @@ -27,8 +27,11 @@ def run_hf_ov_genai_comparison(model_fixture, generation_config, prompt): hf_encoded_output = model.generate(encoded_prompt, **generation_config_hf) hf_output = tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:]) + device = 'CPU' + ov_tokenziers_path = '../../build/openvino_tokenizers/src/' import openvino_genai as ov_genai - pipe = ov_genai.LLMPipeline(path) + + pipe = ov_genai.LLMPipeline(path, device, {}, ov_tokenziers_path) ov_output = pipe.generate(prompt, **generation_config) if hf_output != ov_output: @@ -43,12 +46,12 @@ def stop_criteria_map(): test_cases = [ (dict(max_new_tokens=20, do_sample=False), 'table is made of'), # generation_config, prompt - (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'), - (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'Alan Turing was a'), - (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=30, diversity_penalty=1.0), 'Alan Turing was a'), - (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'), - (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'The Sun is yellow because'), - (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.5), 'The Sun is yellow because'), + # (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'), + # (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'Alan Turing was a'), + # (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=30, diversity_penalty=1.0), 'Alan Turing was a'), + # (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'), + # (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'The Sun is yellow because'), + # (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.5), 'The Sun is yellow because'), ] @pytest.mark.parametrize("generation_config,prompt", test_cases) def test_greedy_decoding(model_fixture, generation_config, prompt): @@ -61,6 +64,7 @@ def test_greedy_decoding(model_fixture, generation_config, prompt): @pytest.mark.parametrize("max_new_tokens", [20, 15]) @pytest.mark.parametrize("diversity_penalty", [1.0, 1.5]) @pytest.mark.parametrize("prompt", prompts) +@pytest.mark.skip # temporarily def test_beam_search_decoding(model_fixture, num_beam_groups, group_size, max_new_tokens, diversity_penalty, prompt): generation_config = dict( @@ -76,6 +80,7 @@ def test_beam_search_decoding(model_fixture, num_beam_groups, group_size, @pytest.mark.parametrize("stop_criteria", ["never", "early", "heuristic"]) @pytest.mark.parametrize("prompt", prompts) @pytest.mark.parametrize("max_new_tokens", [20, 40, 300]) +@pytest.mark.skip # temporarily def test_stop_criteria(model_fixture, stop_criteria, prompt, max_new_tokens): # todo: for long sentences early stop_criteria fails if (stop_criteria == 'early' and max_new_tokens >= 300): From 73e4312de7a98f9f347421ad6903836a1810528b Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 19:14:22 +0400 Subject: [PATCH 12/40] Update Jinja2Cpp fork commit --- src/cpp/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index ffe28a81df..bfb5845553 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -13,7 +13,7 @@ FetchContent_MakeAvailable(nlohmann_json) function(ov_genai_build_jinja2cpp) FetchContent_Declare(jinja2cpp - URL https://github.com/ilya-lavrenov/Jinja2Cpp/archive/a5d002cbf44469775556daea14ba3ccdba1e365a.tar.gz + URL https://github.com/ilya-lavrenov/Jinja2Cpp/archive/5433af6b225cd35df700023cf60df4acdd6cbcf3.tar.gz URL_HASH SHA256=5aa5378d9acf3c44dfb607fd7f16f48b17ffa6495c219957901e9191ffe28900) FetchContent_GetProperties(jinja2cpp) From 54cbb5267d3adb9943fb1f482f422ad5fd631bd8 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 19:16:55 +0400 Subject: [PATCH 13/40] update URL_HASH --- src/cpp/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index bfb5845553..91efccf315 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -14,7 +14,7 @@ FetchContent_MakeAvailable(nlohmann_json) function(ov_genai_build_jinja2cpp) FetchContent_Declare(jinja2cpp URL https://github.com/ilya-lavrenov/Jinja2Cpp/archive/5433af6b225cd35df700023cf60df4acdd6cbcf3.tar.gz - URL_HASH SHA256=5aa5378d9acf3c44dfb607fd7f16f48b17ffa6495c219957901e9191ffe28900) + URL_HASH SHA256=b90f6c44908beaacae8eeb2690d11a6ebb183b4560434698ac00017e7bc07d11) FetchContent_GetProperties(jinja2cpp) if(NOT jinja2cpp_POPULATED) From 82a944910d9e804c5b34f01385571f88e4781dd0 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 19:19:18 +0400 Subject: [PATCH 14/40] remove submodules from .gitmodules --- .gitmodules | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.gitmodules b/.gitmodules index 937468fb64..f72fd83489 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,3 @@ [submodule "thirdparty/openvino_tokenizers"] path = thirdparty/openvino_tokenizers url = https://github.com/openvinotoolkit/openvino_tokenizers.git -[submodule "thirdparty/nlohmann_json"] - path = thirdparty/nlohmann_json - url = https://github.com/nlohmann/json.git -[submodule "thirdparty/Jinja2Cpp"] - path = thirdparty/Jinja2Cpp - url = https://github.com/jinja2cpp/Jinja2Cpp From 75b7c3799e5f24331db64d2c797fbf4acc369a89 Mon Sep 17 00:00:00 2001 From: Pavel Esir <pavel.esir@intel.com> Date: Wed, 22 May 2024 18:12:13 +0200 Subject: [PATCH 15/40] remove group_beam_searcher.hpp; copy fast_tokenizer --- src/cpp/CMakeLists.txt | 9 +++++++-- src/cpp/src/group_beam_searcher.cpp | 2 +- src/cpp/src/group_beam_searcher.hpp | 12 ------------ src/cpp/src/llm_pipeline.cpp | 3 ++- tests/python_tests/test_generate_api.py | 4 ++-- 5 files changed, 12 insertions(+), 18 deletions(-) delete mode 100644 src/cpp/src/group_beam_searcher.hpp diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index 91efccf315..e0151376b4 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -49,8 +49,6 @@ add_library(${TARGET_NAME} SHARED ${SOURCE_FILES}) add_library(openvino::${TARGET_NAME} ALIAS ${TARGET_NAME}) target_include_directories(${TARGET_NAME} - # TODO: remove it, because beam_search algo should not be exposed to end users - PRIVATE "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../text_generation/causal_lm/cpp/>" PUBLIC "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>" "$<INSTALL_INTERFACE:runtime/include>") target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime PRIVATE nlohmann_json::nlohmann_json jinja2cpp) @@ -66,6 +64,13 @@ add_custom_command(TARGET ${TARGET_NAME} POST_BUILD "${CMAKE_CURRENT_SOURCE_DIR}/../python/openvino_genai/$<TARGET_FILE_NAME:${TARGET_NAME}>" COMMENT "Copy ${TARGET_NAME} to src/python/openvino_genai") +# Copy libcore_tokenizers.so to build_dir/openvino_tokenizers/src/ +add_custom_command(TARGET ${TARGET_NAME} POST_BUILD + COMMAND "${CMAKE_COMMAND}" -E copy + "${CMAKE_BINARY_DIR}/_deps/fast_tokenizer-src/lib/libcore_tokenizers.so" + "${CMAKE_BINARY_DIR}/openvino_tokenizers/src/" + COMMENT "Copy libcore_tokenizers.so to build_dir/openvino_tokenizers/src/") + install(TARGETS ${TARGET_NAME} LIBRARY DESTINATION . COMPONENT core_genai RUNTIME DESTINATION . COMPONENT core_genai) # - Windows: `<openvino_dir>\runtime\bin\intel64\Release\` diff --git a/src/cpp/src/group_beam_searcher.cpp b/src/cpp/src/group_beam_searcher.cpp index 1e27f36a0a..312671c8f0 100644 --- a/src/cpp/src/group_beam_searcher.cpp +++ b/src/cpp/src/group_beam_searcher.cpp @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 #include <openvino/runtime/tensor.hpp> -#include "group_beam_searcher.hpp" #include "generation_config_helper.hpp" +#include "openvino/genai/llm_pipeline.hpp" #include "utils.hpp" namespace { diff --git a/src/cpp/src/group_beam_searcher.hpp b/src/cpp/src/group_beam_searcher.hpp deleted file mode 100644 index 5362c9cfae..0000000000 --- a/src/cpp/src/group_beam_searcher.hpp +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include <openvino/runtime/tensor.hpp> -#include "openvino/genai/generation_config.hpp" -#include "openvino/genai/llm_pipeline.hpp" - -namespace ov { - EncodedResults beam_search(ov::InferRequest& lm, ov::Tensor prompts, ov::Tensor attentin_mask, GenerationConfig config); -} diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 4415e507fe..9ea685e583 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -14,7 +14,6 @@ #include "openvino/genai/llm_pipeline.hpp" #include "utils.hpp" #include "generation_config_helper.hpp" -#include "group_beam_searcher.hpp" #include "text_callback_streamer.hpp" @@ -29,6 +28,8 @@ ov::EncodedResults greedy_decoding( bool is_chat_conversation = false ); +EncodedResults beam_search(ov::InferRequest& lm, ov::Tensor prompts, ov::Tensor attentin_mask, GenerationConfig config); + class LLMPipeline::LLMPipelineImpl { public: diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_generate_api.py index d0bec03107..1d46e227c9 100644 --- a/tests/python_tests/test_generate_api.py +++ b/tests/python_tests/test_generate_api.py @@ -28,10 +28,10 @@ def run_hf_ov_genai_comparison(model_fixture, generation_config, prompt): hf_output = tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:]) device = 'CPU' - ov_tokenziers_path = '../../build/openvino_tokenizers/src/' + ov_tokenizers_path = '../../build/openvino_tokenizers/src/' import openvino_genai as ov_genai - pipe = ov_genai.LLMPipeline(path, device, {}, ov_tokenziers_path) + pipe = ov_genai.LLMPipeline(path, device, {}, ov_tokenizers_path) ov_output = pipe.generate(prompt, **generation_config) if hf_output != ov_output: From b6cf9547688e82703277855b61d91a94f2b5a73f Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 20:16:23 +0400 Subject: [PATCH 16/40] rreorganaise components --- pyproject.toml | 2 +- src/cpp/CMakeLists.txt | 12 ++++-------- src/python/CMakeLists.txt | 7 ++++++- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cb373e12c8..f9707988bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ cmake.source-dir = "./" cmake.build-type = "Release" cmake.targets = ["py_generate_pipeline", "genai"] -install.components = ["core_genai", "pygenai"] +install.components = ["wheel_genai"] sdist.cmake = true wheel.packages = ["src/python/openvino_genai"] wheel.install-dir = "openvino_genai" diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index e43e2e9c1d..538d1fcc41 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -71,8 +71,6 @@ install(TARGETS ${TARGET_NAME} LIBRARY DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR} RUNTIME DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR}) -install(TARGETS ${TARGET_NAME} LIBRARY DESTINATION . COMPONENT pygenai RUNTIME DESTINATION . COMPONENT pygenai) - # - Windows: `<openvino_dir>\runtime\bin\intel64\Release\` # - MacOS_x86: `<openvino_dir>/runtime/lib/intel64/Release` # - MacOS_arm64: `<openvino_dir>/runtime/lib/arm64/Release/` @@ -94,14 +92,12 @@ endif() if(MSVC OR APPLE) set(ARCH_DIR ${ARCH_DIR}/${CMAKE_BUILD_TYPE}) endif() -install(TARGETS ${TARGET_NAME} - LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai - RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai) install(TARGETS ${TARGET_NAME} EXPORT openvino_genaiTargets - LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai_dev + LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai ARCHIVE DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai_dev - RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai_dev - INCLUDES DESTINATION runtime/include COMPONENT core_genai_dev) + RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai + NAMELINK_COMPONENT DESTINATION runtime/lib/${ARCH_DIR} COMPONENT HIDDEN + INCLUDES DESTINATION runtime/include) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION runtime/include COMPONENT core_genai_dev) install(EXPORT openvino_genaiTargets FILE openvino_genaiTargets.cmake NAMESPACE openvino:: DESTINATION runtime/cmake) include(CMakePackageConfigHelpers) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 0ebdc78cd9..6fdf487c92 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -18,7 +18,6 @@ endif() pybind11_add_module(py_generate_pipeline py_generate_pipeline.cpp) target_link_libraries(py_generate_pipeline PRIVATE genai) -install(TARGETS py_generate_pipeline LIBRARY DESTINATION . COMPONENT pygenai) # setting RPATH / LC_RPATH depending on platform if(LINUX) @@ -47,3 +46,9 @@ add_custom_command(TARGET py_generate_pipeline POST_BUILD find_package(Python3 REQUIRED COMPONENTS Interpreter Development) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/openvino_genai/ DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR}) install(TARGETS py_generate_pipeline LIBRARY DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR}) + +if(SKBUILD) + # wheel_genai component is used for wheel generation in pyproject.toml + # Don't even add wheel_genai in other cases to hide it from normal packaging process + install(TARGETS genai py_generate_pipeline LIBRARY DESTINATION . COMPONENT wheel_genai RUNTIME DESTINATION . COMPONENT wheel_genai) +endif() From aaf5c78fee75f87d6d3b9d37e0378e4813446c1d Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 20:41:39 +0400 Subject: [PATCH 17/40] add SOVERSION, and requirements-build.txt --- requirements-build.txt | 2 ++ src/cpp/CMakeLists.txt | 5 +++++ 2 files changed, 7 insertions(+) create mode 100644 requirements-build.txt diff --git a/requirements-build.txt b/requirements-build.txt new file mode 100644 index 0000000000..0efe8fea72 --- /dev/null +++ b/requirements-build.txt @@ -0,0 +1,2 @@ +cmake +build diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index 538d1fcc41..5bbcfbaf98 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -59,6 +59,11 @@ target_compile_definitions(${TARGET_NAME} PRIVATE OPENVINO_TOKENIZERS_PATH=\"$<T target_compile_features(${TARGET_NAME} PUBLIC cxx_std_17) +set_target_properties(${TARGET_NAME} PROPERTIES + VERSION ${CMAKE_PROJECT_VERSION} + SOVERSION ${CMAKE_PROJECT_VERSION_MAJOR} +) + # Copy the library to python to allow skipping wheel installation add_custom_command(TARGET ${TARGET_NAME} POST_BUILD COMMAND "${CMAKE_COMMAND}" -E copy From 5537d3b66be45057c9ad76ef45b7825eec0a44b0 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 21:04:02 +0400 Subject: [PATCH 18/40] repalce SKBUILD with EXCLUDE_FROM_ALL because the effect is the same --- src/python/CMakeLists.txt | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 6fdf487c92..bf3efcd4b1 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -47,8 +47,6 @@ find_package(Python3 REQUIRED COMPONENTS Interpreter Development) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/openvino_genai/ DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR}) install(TARGETS py_generate_pipeline LIBRARY DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR}) -if(SKBUILD) - # wheel_genai component is used for wheel generation in pyproject.toml - # Don't even add wheel_genai in other cases to hide it from normal packaging process - install(TARGETS genai py_generate_pipeline LIBRARY DESTINATION . COMPONENT wheel_genai RUNTIME DESTINATION . COMPONENT wheel_genai) -endif() +# wheel_genai component is used for wheel generation in pyproject.toml. +# Exclude wheel_genai from normal packaging process. +install(TARGETS genai py_generate_pipeline LIBRARY DESTINATION . COMPONENT wheel_genai RUNTIME DESTINATION . COMPONENT wheel_genai EXCLUDE_FROM_ALL) From 9966be46f4ee88dff0e587127cb74ad5d12d8c4d Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 21:10:55 +0400 Subject: [PATCH 19/40] fix NAMELINK_COMPONENT --- src/cpp/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index 5bbcfbaf98..5af77dd3b4 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -99,9 +99,9 @@ if(MSVC OR APPLE) endif() install(TARGETS ${TARGET_NAME} EXPORT openvino_genaiTargets LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai + NAMELINK_COMPONENT core_genai_dev ARCHIVE DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai_dev RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai - NAMELINK_COMPONENT DESTINATION runtime/lib/${ARCH_DIR} COMPONENT HIDDEN INCLUDES DESTINATION runtime/include) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION runtime/include COMPONENT core_genai_dev) install(EXPORT openvino_genaiTargets FILE openvino_genaiTargets.cmake NAMESPACE openvino:: DESTINATION runtime/cmake) From 2486e53bef1a4601bee426088bccca874e72d7c3 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 22:04:36 +0400 Subject: [PATCH 20/40] remove extraline --- src/python/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index bf3efcd4b1..2bc0dea878 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -18,7 +18,6 @@ endif() pybind11_add_module(py_generate_pipeline py_generate_pipeline.cpp) target_link_libraries(py_generate_pipeline PRIVATE genai) - # setting RPATH / LC_RPATH depending on platform if(LINUX) # to find libgenai.so in the same folder From 786eac7fc974869e685c584b142df83d529c0678 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Wed, 22 May 2024 22:20:09 +0400 Subject: [PATCH 21/40] add soft restrictions --- requirements-build.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-build.txt b/requirements-build.txt index 0efe8fea72..aaaf7148ec 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -1,2 +1,2 @@ -cmake -build +cmake~=3.23 +build~=1.2.1 From 7324da9ae9383e9ce30af255585e49aa29f0ed15 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 00:43:51 +0400 Subject: [PATCH 22/40] Fix build to unblock packaging --- src/cpp/CMakeLists.txt | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index 2ebed86eae..6db300a4e2 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -75,11 +75,13 @@ install(TARGETS ${TARGET_NAME} RUNTIME DESTINATION python/openvino_genai/ COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR}) # Copy libcore_tokenizers.so to build_dir/openvino_tokenizers/src/ -add_custom_command(TARGET ${TARGET_NAME} POST_BUILD - COMMAND "${CMAKE_COMMAND}" -E copy - "${CMAKE_BINARY_DIR}/_deps/fast_tokenizer-src/lib/libcore_tokenizers.so" - "${CMAKE_BINARY_DIR}/openvino_tokenizers/src/" - COMMENT "Copy libcore_tokenizers.so to build_dir/openvino_tokenizers/src/") +if(NOT MSVC) + add_custom_command(TARGET ${TARGET_NAME} POST_BUILD + COMMAND "${CMAKE_COMMAND}" -E copy + "${CMAKE_BINARY_DIR}/_deps/fast_tokenizer-src/lib/libcore_tokenizers.so" + "${CMAKE_BINARY_DIR}/openvino_tokenizers/src/" + COMMENT "Copy libcore_tokenizers.so to build_dir/openvino_tokenizers/src/") +endif() # - Windows: `<openvino_dir>\runtime\bin\intel64\Release\` # - MacOS_x86: `<openvino_dir>/runtime/lib/intel64/Release` From 5577e84899561c97d5f099ba80c395f36aa9f16f Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 11:52:47 +0400 Subject: [PATCH 23/40] improve naming --- CMakeLists.txt | 5 +++-- pyproject.toml | 2 +- src/cpp/CMakeLists.txt | 12 ++++++------ ...iConfig.cmake.in => OpenVINOGenAIConfig.cmake.in} | 2 +- src/python/CMakeLists.txt | 6 +++--- text_generation/causal_lm/cpp/CMakeLists.txt | 2 +- 6 files changed, 15 insertions(+), 14 deletions(-) rename src/cpp/{openvino_genaiConfig.cmake.in => OpenVINOGenAIConfig.cmake.in} (70%) diff --git a/CMakeLists.txt b/CMakeLists.txt index ac392233a6..6c01b378c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,13 +13,14 @@ if(NOT GENERATOR_IS_MULTI_CONFIG_VAR AND NOT DEFINED CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel ...") endif() -project(openvino_genai VERSION 2024.2.0.0) +project(OpenVINOGenAI VERSION 2024.2.0.0) add_subdirectory(./thirdparty/openvino_tokenizers/ "${CMAKE_CURRENT_BINARY_DIR}/openvino_tokenizers/") add_subdirectory(src) add_subdirectory(text_generation/causal_lm/cpp) install(DIRECTORY text_generation/causal_lm/cpp/ DESTINATION samples/cpp/causal_lm COMPONENT cpp_samples_genai) -install(FILES LICENSE third-party-programs.txt DESTINATION licensing_genai COMPONENT licensing_genai) # TODO: how to merge with OPenvino +install(FILES LICENSE DESTINATION licensing COMPONENT licensing_genai RENAME LICENSE-GENAI) +install(FILES third-party-programs.txt DESTINATION licensing COMPONENT licensing_genai RENAME third-party-programs-genai.txt) set(CPACK_GENERATOR "ZIP") include(CPack) diff --git a/pyproject.toml b/pyproject.toml index f9707988bf..3354bf3e70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ [tool.scikit-build] cmake.source-dir = "./" cmake.build-type = "Release" -cmake.targets = ["py_generate_pipeline", "genai"] +cmake.targets = ["py_generate_pipeline", "openvino::genai"] install.components = ["wheel_genai"] sdist.cmake = true wheel.packages = ["src/python/openvino_genai"] diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index 6db300a4e2..30d95d3553 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -104,17 +104,17 @@ endif() if(MSVC OR APPLE) set(ARCH_DIR ${ARCH_DIR}/${CMAKE_BUILD_TYPE}) endif() -install(TARGETS ${TARGET_NAME} EXPORT openvino_genaiTargets +install(TARGETS ${TARGET_NAME} EXPORT OpenVINOGenAITargets LIBRARY DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai NAMELINK_COMPONENT core_genai_dev ARCHIVE DESTINATION runtime/lib/${ARCH_DIR} COMPONENT core_genai_dev RUNTIME DESTINATION runtime/bin/${ARCH_DIR} COMPONENT core_genai INCLUDES DESTINATION runtime/include) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION runtime/include COMPONENT core_genai_dev) -install(EXPORT openvino_genaiTargets FILE openvino_genaiTargets.cmake NAMESPACE openvino:: DESTINATION runtime/cmake) +install(EXPORT OpenVINOGenAITargets FILE OpenVINOGenAITargets.cmake NAMESPACE openvino:: DESTINATION runtime/cmake) include(CMakePackageConfigHelpers) -configure_package_config_file(openvino_genaiConfig.cmake.in "${CMAKE_BINARY_DIR}/openvino_genaiConfig.cmake" INSTALL_DESTINATION runtime/cmake) -install(FILES "${CMAKE_BINARY_DIR}/openvino_genaiConfig.cmake" "${CMAKE_BINARY_DIR}/openvino_genaiConfigVersion.cmake" DESTINATION runtime/cmake COMPONENT core_genai_dev) +configure_package_config_file(OpenVINOGenAIConfig.cmake.in "${CMAKE_BINARY_DIR}/OpenVINOGenAIConfig.cmake" INSTALL_DESTINATION runtime/cmake) +install(FILES "${CMAKE_BINARY_DIR}/OpenVINOGenAIConfig.cmake" "${CMAKE_BINARY_DIR}/OpenVINOGenAIConfig.cmake" DESTINATION runtime/cmake COMPONENT core_genai_dev) include(CMakePackageConfigHelpers) -write_basic_package_version_file("${CMAKE_BINARY_DIR}/openvino_genaiConfigVersion.cmake" VERSION ${CMAKE_PROJECT_VERSION} COMPATIBILITY AnyNewerVersion) -export(EXPORT openvino_genaiTargets FILE "${CMAKE_BINARY_DIR}/openvino_genaiTargets.cmake" NAMESPACE openvino::) +write_basic_package_version_file("${CMAKE_BINARY_DIR}/OpenVINOGenAIConfigVersion.cmake" VERSION ${CMAKE_PROJECT_VERSION} COMPATIBILITY AnyNewerVersion) +export(EXPORT OpenVINOGenAITargets FILE "${CMAKE_BINARY_DIR}/OpenVINOGenAITargets.cmake" NAMESPACE openvino::) diff --git a/src/cpp/openvino_genaiConfig.cmake.in b/src/cpp/OpenVINOGenAIConfig.cmake.in similarity index 70% rename from src/cpp/openvino_genaiConfig.cmake.in rename to src/cpp/OpenVINOGenAIConfig.cmake.in index abfd33b524..18c0bb4e48 100644 --- a/src/cpp/openvino_genaiConfig.cmake.in +++ b/src/cpp/OpenVINOGenAIConfig.cmake.in @@ -4,7 +4,7 @@ include(CMakeFindDependencyMacro) find_dependency(OpenVINO COMPONENTS Runtime) if(NOT TARGET genai) - include("${CMAKE_CURRENT_LIST_DIR}/openvino_genaiTargets.cmake") + include("${CMAKE_CURRENT_LIST_DIR}/OpenVINOGenAITargets.cmake") endif() check_required_components(openvino_genai) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 2bc0dea878..00722b6fff 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -5,8 +5,8 @@ include(FetchContent) FetchContent_Declare( pybind11 - GIT_REPOSITORY https://github.com/pybind/pybind11 - GIT_TAG v2.12.0 + URL https://github.com/pybind/pybind11/archive/3e9dfa2866941655c56877882565e7577de6fc7b.tar.gz + URL_HASH SHA256=9a7d245f405f470798b9d2a48912cc97230658024775299eac203f7c9c9ae37c ) set(CMAKE_POSITION_INDEPENDENT_CODE ON) FetchContent_GetProperties(pybind11) @@ -16,7 +16,7 @@ if(NOT pybind11_POPULATED) endif() pybind11_add_module(py_generate_pipeline py_generate_pipeline.cpp) -target_link_libraries(py_generate_pipeline PRIVATE genai) +target_link_libraries(py_generate_pipeline PRIVATE openvino::genai) # setting RPATH / LC_RPATH depending on platform if(LINUX) diff --git a/text_generation/causal_lm/cpp/CMakeLists.txt b/text_generation/causal_lm/cpp/CMakeLists.txt index 1998c3ccb6..8c57d65fae 100644 --- a/text_generation/causal_lm/cpp/CMakeLists.txt +++ b/text_generation/causal_lm/cpp/CMakeLists.txt @@ -10,7 +10,7 @@ else() set(OPENVINO_TOKENIZERS_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../bin/openvino_tokenizers.dll) # TODO: I'll go away after the generate() gets a way to find openvino_tokenizers endif() -find_package(openvino_genai REQUIRED PATHS +find_package(OpenVINOGenAI REQUIRED PATHS "${CMAKE_BINARY_DIR}" # Reuse the package from the build. ${OpenVINO_DIR} # GenAI may be installed alogside OpenVINO. ) From b679fc70c10b04f1cd1043f7439c26b05259da1d Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 12:01:03 +0400 Subject: [PATCH 24/40] install samples --- .github/workflows/genai_package.yml | 8 ++++---- .github/workflows/genai_python_lib.yml | 2 +- text_generation/causal_lm/cpp/CMakeLists.txt | 5 +++++ 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index 65d3e6da3e..8a92b4f492 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -17,11 +17,11 @@ jobs: - run: source ./ov/setupvars.sh && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: source ./ov/setupvars.sh && cmake --build ./build/ --config Release --target package -j - run: source ./ov/setupvars.sh && cmake --install ./build/ --config Release --prefix ov - - run: ov/samples/cpp/build_samples.sh -b "${{ github.workspace }}/s pace" + - run: ov/samples/cpp/build_samples.sh -i ${{ github.workspace }}/s\ pace - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: source ./ov/setupvars.sh && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 - - run: source ./ov/setupvars.sh && timeout 50s "${{ github.workspace }}/s pace/intel64/Release/greedy_causal_lm" ./TinyLlama-1.1B-Chat-v1.0/ "" + - run: source ./ov/setupvars.sh && timeout 50s "${{ github.workspace }}/s\ pace/samples_bin/greedy_causal_lm" ./TinyLlama-1.1B-Chat-v1.0/ "" windows_genai_package: runs-on: windows-latest @@ -40,8 +40,8 @@ jobs: - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --build ./build/ --config Release --target package -j - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --install ./build/ --config Release --prefix w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64 - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\samples\cpp\build_samples_msvc.bat -b "${{ github.workspace }}/samples_build" + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\samples\cpp\build_samples_msvc.bat -i "${{ github.workspace }}/samples_install" - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && "${{ github.workspace }}/samples_build/intel64/Release/greedy_causal_lm" .\TinyLlama-1.1B-Chat-v1.0\ "" + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && "${{ github.workspace }}/samples_install/samples_bin/greedy_causal_lm" .\TinyLlama-1.1B-Chat-v1.0\ "" diff --git a/.github/workflows/genai_python_lib.yml b/.github/workflows/genai_python_lib.yml index 29f537858a..da63cc5f29 100644 --- a/.github/workflows/genai_python_lib.yml +++ b/.github/workflows/genai_python_lib.yml @@ -2,7 +2,7 @@ name: genai_python_lib on: pull_request jobs: ubuntu_genai_python_lib: - runs-on: ubuntu-20.04-16-cores + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v4 with: diff --git a/text_generation/causal_lm/cpp/CMakeLists.txt b/text_generation/causal_lm/cpp/CMakeLists.txt index 8c57d65fae..7e3ec23fde 100644 --- a/text_generation/causal_lm/cpp/CMakeLists.txt +++ b/text_generation/causal_lm/cpp/CMakeLists.txt @@ -51,3 +51,8 @@ target_link_libraries(chat_sample PRIVATE openvino::genai) target_include_directories(chat_sample PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}") set_target_properties(chat_sample PROPERTIES CXX_STANDARD 17) set_target_properties(chat_sample PROPERTIES CXX_STANDARD_REQUIRED ON) + +install(TARGETS greedy_causal_lm beam_search_causal_lm speculative_decoding_lm prompt_lookup_decoding_lm chat_sample + RUNTIME DESTINATION samples_bin/ + COMPONENT samples_bin + EXCLUDE_FROM_ALL) From 26f9fe1b5e1b1087c860491ce9dd0620859b30b5 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 12:08:29 +0400 Subject: [PATCH 25/40] remove quotes --- .github/workflows/genai_package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index 8a92b4f492..abf2f09bb7 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -21,7 +21,7 @@ jobs: - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: source ./ov/setupvars.sh && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 - - run: source ./ov/setupvars.sh && timeout 50s "${{ github.workspace }}/s\ pace/samples_bin/greedy_causal_lm" ./TinyLlama-1.1B-Chat-v1.0/ "" + - run: source ./ov/setupvars.sh && timeout 50s ${{ github.workspace }}/s\ pace/samples_bin/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "" windows_genai_package: runs-on: windows-latest From 1dcd40ba1ccee46ac9adbe569c5eb874af64a577 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 12:10:33 +0400 Subject: [PATCH 26/40] use main target name because an alias can't be specified in cmake --target --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3354bf3e70..f9707988bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ [tool.scikit-build] cmake.source-dir = "./" cmake.build-type = "Release" -cmake.targets = ["py_generate_pipeline", "openvino::genai"] +cmake.targets = ["py_generate_pipeline", "genai"] install.components = ["wheel_genai"] sdist.cmake = true wheel.packages = ["src/python/openvino_genai"] From 8c00ccb7c647d9f8ccf3ffbe2eb4a051e499a778 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 12:55:23 +0400 Subject: [PATCH 27/40] define CMAKE_BUILD_PARALLEL_LEVEL --- .github/workflows/genai_python_lib.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/genai_python_lib.yml b/.github/workflows/genai_python_lib.yml index da63cc5f29..99426eddc6 100644 --- a/.github/workflows/genai_python_lib.yml +++ b/.github/workflows/genai_python_lib.yml @@ -18,7 +18,7 @@ jobs: - run: python -m pip install --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly # Can't load CentOS libraries from the archive - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: PYTHONPATH=./src/python/ python -c "from openvino_genai import LLMPipeline" - - run: source ./ov/setupvars.sh && python -m pip install --pre . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly + - run: source ./ov/setupvars.sh && CMAKE_BUILD_PARALLEL_LEVEL= python -m pip install --pre . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly - run: python -c "from openvino_genai import LLMPipeline" - name: GenAI Python API tests run: | @@ -50,5 +50,5 @@ jobs: - run: python -m pip install "numpy<1.27" - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: set "PYTHONPATH=./src/python;" && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -c "from openvino_genai import LLMPipeline" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install . + - run: set CMAKE_BUILD_PARALLEL_LEVEL= && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat python -m pip install . - run: python -c "from openvino_genai import LLMPipeline" From 61fba583545a1025c00ccf081fc8b7a5b05adb56 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 13:03:56 +0400 Subject: [PATCH 28/40] Ensure ./requirements-build.txt won't outdate --- .github/dependabot.yml | 4 ++++ .github/workflows/genai_package.yml | 8 ++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 9ab4587c2a..789167949f 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,5 +1,9 @@ version: 2 updates: + - package-ecosystem: "pip" + directory: "./" + schedule: + interval: "weekly" - package-ecosystem: "pip" directory: "image_generation/stable_diffusion_1_5/cpp/scripts/" schedule: diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index abf2f09bb7..34f6cebf51 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -18,7 +18,9 @@ jobs: - run: source ./ov/setupvars.sh && cmake --build ./build/ --config Release --target package -j - run: source ./ov/setupvars.sh && cmake --install ./build/ --config Release --prefix ov - run: ov/samples/cpp/build_samples.sh -i ${{ github.workspace }}/s\ pace - - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt + # GitHub Actions already provides what is listed in ./requirements-build.txt but the internal + # build system doesn't. Install ./requirements-build.txt to detect possible conflicts. + - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt -r ./requirements-build.txt - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: source ./ov/setupvars.sh && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 - run: source ./ov/setupvars.sh && timeout 50s ${{ github.workspace }}/s\ pace/samples_bin/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "" @@ -41,7 +43,9 @@ jobs: - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --build ./build/ --config Release --target package -j - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --install ./build/ --config Release --prefix w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64 - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\samples\cpp\build_samples_msvc.bat -i "${{ github.workspace }}/samples_install" - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt + # GitHub Actions already provides what is listed in ./requirements-build.txt but the internal + # build system doesn't. Install ./requirements-build.txt to detect possible conflicts. + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt -r ./requirements-build.txt - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && "${{ github.workspace }}/samples_install/samples_bin/greedy_causal_lm" .\TinyLlama-1.1B-Chat-v1.0\ "" From d78fa3b012dbf932b997b9b1f23ccc279836ab92 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 13:08:07 +0400 Subject: [PATCH 29/40] Use ./requirements-build.txt in python lib build --- .github/workflows/genai_package.yml | 8 ++------ .github/workflows/genai_python_lib.yml | 8 ++++++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index 34f6cebf51..abf2f09bb7 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -18,9 +18,7 @@ jobs: - run: source ./ov/setupvars.sh && cmake --build ./build/ --config Release --target package -j - run: source ./ov/setupvars.sh && cmake --install ./build/ --config Release --prefix ov - run: ov/samples/cpp/build_samples.sh -i ${{ github.workspace }}/s\ pace - # GitHub Actions already provides what is listed in ./requirements-build.txt but the internal - # build system doesn't. Install ./requirements-build.txt to detect possible conflicts. - - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt -r ./requirements-build.txt + - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: source ./ov/setupvars.sh && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 - run: source ./ov/setupvars.sh && timeout 50s ${{ github.workspace }}/s\ pace/samples_bin/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "" @@ -43,9 +41,7 @@ jobs: - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --build ./build/ --config Release --target package -j - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --install ./build/ --config Release --prefix w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64 - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\samples\cpp\build_samples_msvc.bat -i "${{ github.workspace }}/samples_install" - # GitHub Actions already provides what is listed in ./requirements-build.txt but the internal - # build system doesn't. Install ./requirements-build.txt to detect possible conflicts. - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt -r ./requirements-build.txt + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && "${{ github.workspace }}/samples_install/samples_bin/greedy_causal_lm" .\TinyLlama-1.1B-Chat-v1.0\ "" diff --git a/.github/workflows/genai_python_lib.yml b/.github/workflows/genai_python_lib.yml index 99426eddc6..b23645d285 100644 --- a/.github/workflows/genai_python_lib.yml +++ b/.github/workflows/genai_python_lib.yml @@ -16,7 +16,9 @@ jobs: - run: source ./ov/setupvars.sh && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: source ./ov/setupvars.sh && cmake --build ./build/ --config Release -j - run: python -m pip install --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly # Can't load CentOS libraries from the archive - - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] + # GitHub Actions already provides what is listed in ./requirements-build.txt but the internal + # build system doesn't. Install ./requirements-build.txt to detect possible conflicts. + - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./requirements-build.txt - run: PYTHONPATH=./src/python/ python -c "from openvino_genai import LLMPipeline" - run: source ./ov/setupvars.sh && CMAKE_BUILD_PARALLEL_LEVEL= python -m pip install --pre . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly - run: python -c "from openvino_genai import LLMPipeline" @@ -48,7 +50,9 @@ jobs: - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --build ./build/ --config Release -j - run: python -m pip install "numpy<1.27" - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] + # GitHub Actions already provides what is listed in ./requirements-build.txt but the internal + # build system doesn't. Install ./requirements-build.txt to detect possible conflicts. + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./requirements-build.txt - run: set "PYTHONPATH=./src/python;" && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -c "from openvino_genai import LLMPipeline" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. - run: set CMAKE_BUILD_PARALLEL_LEVEL= && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat python -m pip install . - run: python -c "from openvino_genai import LLMPipeline" From 757b73801dae245e66038311f62cce3c69216bf5 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 13:23:25 +0400 Subject: [PATCH 30/40] Add missing && --- .github/workflows/genai_python_lib.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/genai_python_lib.yml b/.github/workflows/genai_python_lib.yml index b23645d285..b5afeeb6b3 100644 --- a/.github/workflows/genai_python_lib.yml +++ b/.github/workflows/genai_python_lib.yml @@ -54,5 +54,5 @@ jobs: # build system doesn't. Install ./requirements-build.txt to detect possible conflicts. - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./requirements-build.txt - run: set "PYTHONPATH=./src/python;" && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -c "from openvino_genai import LLMPipeline" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. - - run: set CMAKE_BUILD_PARALLEL_LEVEL= && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat python -m pip install . + - run: set CMAKE_BUILD_PARALLEL_LEVEL= && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install . - run: python -c "from openvino_genai import LLMPipeline" From 51ace23487c21a5a1751a904992234accb135331 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 13:44:18 +0400 Subject: [PATCH 31/40] Test Debug --- .github/workflows/genai_package.yml | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index abf2f09bb7..ba958a9983 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -2,6 +2,9 @@ name: genai_package on: pull_request jobs: ubuntu_genai_package: + strategy: + matrix: + build-type: [Release, Debug] runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v4 @@ -14,14 +17,19 @@ jobs: - run: curl https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.1/linux/l_openvino_toolkit_ubuntu20_2024.1.0.15008.f4afc983258_x86_64.tgz | tar --directory ./ov/ --strip-components 1 -xz - run: sudo ./ov/install_dependencies/install_openvino_dependencies.sh - run: sudo apt-get install libtbb-dev - - run: source ./ov/setupvars.sh && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - - run: source ./ov/setupvars.sh && cmake --build ./build/ --config Release --target package -j - - run: source ./ov/setupvars.sh && cmake --install ./build/ --config Release --prefix ov + - run: source ./ov/setupvars.sh && cmake -DCMAKE_BUILD_TYPE=${{ matrix.build-type }} -S ./ -B ./build/ + - run: source ./ov/setupvars.sh && cmake --build ./build/ --config ${{ matrix.build-type }} --target package -j + - run: source ./ov/setupvars.sh && cmake --install ./build/ --config ${{ matrix.build-type }} --prefix ov - run: ov/samples/cpp/build_samples.sh -i ${{ github.workspace }}/s\ pace + if: ${{ 'Release' == matrix.build-type }} # build_samples enforces Release build - run: source ./ov/setupvars.sh && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt + if: ${{ 'Release' == matrix.build-type }} - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] + if: ${{ 'Release' == matrix.build-type }} - run: source ./ov/setupvars.sh && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 + if: ${{ 'Release' == matrix.build-type }} - run: source ./ov/setupvars.sh && timeout 50s ${{ github.workspace }}/s\ pace/samples_bin/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "" + if: ${{ 'Release' == matrix.build-type }} windows_genai_package: runs-on: windows-latest @@ -37,11 +45,16 @@ jobs: python-version: 3.8 - run: curl --output ov.zip https://storage.openvinotoolkit.org/repositories/openvino/packages/nightly/2024.2.0-15349-765302e0de1/w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64.zip - run: unzip ov.zip - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --build ./build/ --config Release --target package -j - - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --install ./build/ --config Release --prefix w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64 + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake -DCMAKE_BUILD_TYPE=${{ matrix.build-type }} -S ./ -B ./build/ + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --build ./build/ --config ${{ matrix.build-type }} --target package -j + - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && cmake --install ./build/ --config ${{ matrix.build-type }} --prefix w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64 - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\samples\cpp\build_samples_msvc.bat -i "${{ github.workspace }}/samples_install" + if: ${{ 'Release' == matrix.build-type }} # build_samples enforces Release build - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install --upgrade-strategy eager -r text_generation/causal_lm/cpp/requirements.txt + if: ${{ 'Release' == matrix.build-type }} - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] + if: ${{ 'Release' == matrix.build-type }} - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 + if: ${{ 'Release' == matrix.build-type }} - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && "${{ github.workspace }}/samples_install/samples_bin/greedy_causal_lm" .\TinyLlama-1.1B-Chat-v1.0\ "" + if: ${{ 'Release' == matrix.build-type }} From e53c525b9f4ab8097fb8d2b6567f83f6af62505c Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 13:49:19 +0400 Subject: [PATCH 32/40] add matrix for windows_genai_package --- .github/workflows/genai_package.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index ba958a9983..42ef1da025 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -32,6 +32,9 @@ jobs: if: ${{ 'Release' == matrix.build-type }} windows_genai_package: + strategy: + matrix: + build-type: [Release, Debug] runs-on: windows-latest defaults: run: From 73ac7b1228d4ffb85c0350f834a9c4d894691bbc Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 14:09:06 +0400 Subject: [PATCH 33/40] openvino_tokenizers from form --- thirdparty/openvino_tokenizers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thirdparty/openvino_tokenizers b/thirdparty/openvino_tokenizers index c754503462..d106306c7c 160000 --- a/thirdparty/openvino_tokenizers +++ b/thirdparty/openvino_tokenizers @@ -1 +1 @@ -Subproject commit c754503462f569b648b598d57ff91ea57bb8deb1 +Subproject commit d106306c7c35ff1b85d977b0fe23b7861c168b56 From e7e50cbb9efee568d9c1de5487b05b9ad6d5f186 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 14:13:51 +0400 Subject: [PATCH 34/40] update openvino_tokenizers --- thirdparty/openvino_tokenizers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thirdparty/openvino_tokenizers b/thirdparty/openvino_tokenizers index d106306c7c..c7aee81172 160000 --- a/thirdparty/openvino_tokenizers +++ b/thirdparty/openvino_tokenizers @@ -1 +1 @@ -Subproject commit d106306c7c35ff1b85d977b0fe23b7861c168b56 +Subproject commit c7aee81172fcfe526d7c2d2821464e8b3bc765c2 From 33394076ba3c7fa86946f9fd6da888dcde9f2c22 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 14:19:36 +0400 Subject: [PATCH 35/40] update openvino_tokenizers --- thirdparty/openvino_tokenizers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thirdparty/openvino_tokenizers b/thirdparty/openvino_tokenizers index c7aee81172..f451741c4a 160000 --- a/thirdparty/openvino_tokenizers +++ b/thirdparty/openvino_tokenizers @@ -1 +1 @@ -Subproject commit c7aee81172fcfe526d7c2d2821464e8b3bc765c2 +Subproject commit f451741c4a19444088c3fff6c04c4e2af8f87af5 From 9b5b9152eacd8552a22d084c0810d0401d939437 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 14:35:18 +0400 Subject: [PATCH 36/40] update openvino_tokenizers --- thirdparty/openvino_tokenizers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thirdparty/openvino_tokenizers b/thirdparty/openvino_tokenizers index f451741c4a..6c9c6b8cd4 160000 --- a/thirdparty/openvino_tokenizers +++ b/thirdparty/openvino_tokenizers @@ -1 +1 @@ -Subproject commit f451741c4a19444088c3fff6c04c4e2af8f87af5 +Subproject commit 6c9c6b8cd4b2fadb827a46e64c2572b01192d6ef From 1fe85b91c4b6892299dbaa5a07c97c19355b7121 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 15:23:00 +0400 Subject: [PATCH 37/40] revert openvino_tokenizers --- thirdparty/openvino_tokenizers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thirdparty/openvino_tokenizers b/thirdparty/openvino_tokenizers index 6c9c6b8cd4..c754503462 160000 --- a/thirdparty/openvino_tokenizers +++ b/thirdparty/openvino_tokenizers @@ -1 +1 @@ -Subproject commit 6c9c6b8cd4b2fadb827a46e64c2572b01192d6ef +Subproject commit c754503462f569b648b598d57ff91ea57bb8deb1 From 7e2393057e58bf7c8f36982be1d42531e0319fac Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 15:58:58 +0400 Subject: [PATCH 38/40] tokenizers from fork --- .github/workflows/genai_python_lib.yml | 4 ++-- thirdparty/openvino_tokenizers | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/genai_python_lib.yml b/.github/workflows/genai_python_lib.yml index b5afeeb6b3..f00ce286aa 100644 --- a/.github/workflows/genai_python_lib.yml +++ b/.github/workflows/genai_python_lib.yml @@ -20,7 +20,7 @@ jobs: # build system doesn't. Install ./requirements-build.txt to detect possible conflicts. - run: source ./ov/setupvars.sh && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./requirements-build.txt - run: PYTHONPATH=./src/python/ python -c "from openvino_genai import LLMPipeline" - - run: source ./ov/setupvars.sh && CMAKE_BUILD_PARALLEL_LEVEL= python -m pip install --pre . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly + - run: source ./ov/setupvars.sh && CMAKE_BUILD_PARALLEL_LEVEL="" python -m pip install --pre . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly - run: python -c "from openvino_genai import LLMPipeline" - name: GenAI Python API tests run: | @@ -54,5 +54,5 @@ jobs: # build system doesn't. Install ./requirements-build.txt to detect possible conflicts. - run: call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./requirements-build.txt - run: set "PYTHONPATH=./src/python;" && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -c "from openvino_genai import LLMPipeline" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. - - run: set CMAKE_BUILD_PARALLEL_LEVEL= && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install . + - run: set CMAKE_BUILD_PARALLEL_LEVEL="" && call w_openvino_toolkit_windows_2024.2.0.dev20240515_x86_64\setupvars.bat && python -m pip install . - run: python -c "from openvino_genai import LLMPipeline" diff --git a/thirdparty/openvino_tokenizers b/thirdparty/openvino_tokenizers index c754503462..eb5abc686b 160000 --- a/thirdparty/openvino_tokenizers +++ b/thirdparty/openvino_tokenizers @@ -1 +1 @@ -Subproject commit c754503462f569b648b598d57ff91ea57bb8deb1 +Subproject commit eb5abc686b306e58e0e3055be7ec3fc98393d88c From 62f5e34befe12fe78b6716a152712943a6e074a8 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 16:52:28 +0400 Subject: [PATCH 39/40] update tokenizers --- thirdparty/openvino_tokenizers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thirdparty/openvino_tokenizers b/thirdparty/openvino_tokenizers index eb5abc686b..b45b752edf 160000 --- a/thirdparty/openvino_tokenizers +++ b/thirdparty/openvino_tokenizers @@ -1 +1 @@ -Subproject commit eb5abc686b306e58e0e3055be7ec3fc98393d88c +Subproject commit b45b752edf0245f65bcc0c2c6925b771fe55c4b5 From 63262d738ace4084dbe6fd02db0b2fd78e718c96 Mon Sep 17 00:00:00 2001 From: Wovchena <vladimir.zlobin@intel.com> Date: Thu, 23 May 2024 17:04:23 +0400 Subject: [PATCH 40/40] centos7_2024.2.0.dev --- .github/workflows/genai_package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/genai_package.yml b/.github/workflows/genai_package.yml index 42ef1da025..e618586d26 100644 --- a/.github/workflows/genai_package.yml +++ b/.github/workflows/genai_package.yml @@ -14,7 +14,7 @@ jobs: with: python-version: 3.8 - run: mkdir ./ov/ - - run: curl https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.1/linux/l_openvino_toolkit_ubuntu20_2024.1.0.15008.f4afc983258_x86_64.tgz | tar --directory ./ov/ --strip-components 1 -xz + - run: curl https://storage.openvinotoolkit.org/repositories/openvino/packages/nightly/2024.2.0-15454-0d95325972f/l_openvino_toolkit_centos7_2024.2.0.dev20240522_x86_64.tgz | tar --directory ./ov/ --strip-components 1 -xz - run: sudo ./ov/install_dependencies/install_openvino_dependencies.sh - run: sudo apt-get install libtbb-dev - run: source ./ov/setupvars.sh && cmake -DCMAKE_BUILD_TYPE=${{ matrix.build-type }} -S ./ -B ./build/