|
16 | 16 | #include "visual_language/internvl_chat/classes.hpp"
|
17 | 17 |
|
18 | 18 | #include "utils.hpp"
|
19 |
| -#include <regex> |
20 |
| - |
21 |
| -namespace { |
22 |
| - |
23 |
| -std::regex UNIVERSAL_PATTERN{R"(<ov_genai_image_(\d+)>)"}; |
24 |
| - |
25 |
| -} |
26 | 19 |
|
27 | 20 | namespace ov::genai {
|
28 | 21 |
|
@@ -296,49 +289,46 @@ bool InputsEmbedder::prompt_has_image_tag(const std::string& prompt) const {
|
296 | 289 | return m_impl->prompt_has_image_tag(prompt);
|
297 | 290 | }
|
298 | 291 |
|
299 |
| -std::pair<std::string, std::vector<size_t>> unify_prompt( |
| 292 | +void verify_ids(const std::vector<size_t>& image_ids, size_t base_id, size_t n_images) { |
| 293 | + for (size_t idx : image_ids) { |
| 294 | + OPENVINO_ASSERT(base_id <= idx, "Referring to older images isn't implemented"); |
| 295 | + OPENVINO_ASSERT(idx < base_id + n_images, "Missing image ", idx); |
| 296 | + } |
| 297 | +} |
| 298 | + |
| 299 | +std::pair<std::string, std::vector<size_t>> normalize_prompt( |
300 | 300 | const std::string& prompt,
|
301 | 301 | const std::string& native_tag,
|
302 |
| - const std::string& unified_tag_to_native_tag, |
303 |
| - size_t n_new_images, |
304 |
| - size_t first_new_image_id |
| 302 | + const std::string& automatic_tag, |
| 303 | + size_t base_id, |
| 304 | + size_t n_images |
305 | 305 | ) {
|
306 |
| - bool found_universal_tag = std::regex_search(prompt, UNIVERSAL_PATTERN); |
307 |
| - bool found_native_tag = prompt.find(native_tag) != std::string::npos; |
308 |
| - OPENVINO_ASSERT(!(found_universal_tag && found_native_tag), "Prompt can contain only one type of image tags."); |
309 |
| - std::stringstream images_prompt; |
310 |
| - if (!found_universal_tag && ! found_native_tag) { |
311 |
| - for (size_t i = first_new_image_id; i < n_new_images + first_new_image_id; ++i) { |
312 |
| - images_prompt << "<ov_genai_image_" << i << ">"; |
313 |
| - } |
| 306 | + size_t pos = prompt.find(native_tag); |
| 307 | + auto [image_prompt, image_sequence] = universal_to_native(prompt, [&](std::ostream& os, size_t) { |
| 308 | + os << automatic_tag; |
| 309 | + }); |
| 310 | + if (!image_sequence.empty()) { |
| 311 | + OPENVINO_ASSERT(pos == std::string::npos, "Prompt can contain only one type of image tags."); |
| 312 | + verify_ids(image_sequence, base_id, n_images); |
| 313 | + return {std::move(image_prompt), std::move(image_sequence)}; |
314 | 314 | }
|
315 |
| - images_prompt << prompt; |
316 |
| - |
317 |
| - std::vector<size_t> images_sequence; |
318 |
| - std::string unified_prompt = images_prompt.str(); |
319 |
| - std::sregex_iterator end_it; |
320 |
| - if (found_native_tag) { |
321 |
| - size_t pos = 0; |
322 |
| - while ((pos = unified_prompt.find(native_tag, pos)) != std::string::npos) { |
323 |
| - images_sequence.push_back(first_new_image_id + images_sequence.size()); |
324 |
| - pos += native_tag.length(); |
325 |
| - } |
326 |
| - OPENVINO_ASSERT(images_sequence.size() == n_new_images); |
327 |
| - } else { |
328 |
| - bool found = true; |
329 |
| - while (found) { |
330 |
| - found = false; |
331 |
| - for (std::sregex_iterator it(unified_prompt.begin(), unified_prompt.end(), UNIVERSAL_PATTERN); it != end_it; ++it) { |
332 |
| - images_sequence.push_back(std::stoi((*it)[1].str())); |
333 |
| - OPENVINO_ASSERT(images_sequence.back() < n_new_images + first_new_image_id, "Missing image ", images_sequence.back()); |
334 |
| - OPENVINO_ASSERT(first_new_image_id <= images_sequence.back(), "Referring to older images isn't implemented"); |
335 |
| - unified_prompt.replace(it->position(), it->length(), unified_tag_to_native_tag); |
336 |
| - found = true; |
337 |
| - break; |
338 |
| - } |
339 |
| - } |
| 315 | + // Restore ids from native tags |
| 316 | + while (pos != std::string::npos) { |
| 317 | + image_sequence.push_back(base_id + image_sequence.size()); |
| 318 | + pos = prompt.find(native_tag, pos + native_tag.length()); |
| 319 | + } |
| 320 | + if (!image_sequence.empty()) { |
| 321 | + OPENVINO_ASSERT(image_sequence.size() == n_images, "The number of native image tags and provided images must match because it's ambiguous which image should be ignored."); |
| 322 | + return {std::move(image_prompt), std::move(image_sequence)}; |
| 323 | + } |
| 324 | + // Prepend automatic tags |
| 325 | + std::stringstream stream; |
| 326 | + for (size_t relative_id = 0; relative_id < n_images; relative_id++) { |
| 327 | + image_sequence.push_back(base_id + relative_id); |
| 328 | + stream << automatic_tag; |
340 | 329 | }
|
341 |
| - return {std::move(unified_prompt), std::move(images_sequence)}; |
| 330 | + stream << prompt; |
| 331 | + return {stream.str(), std::move(image_sequence)}; |
342 | 332 | }
|
343 | 333 |
|
344 | 334 | } // namespace ov::genai
|
0 commit comments