
Commit 3d0e5ba

Temp

1 parent be9e203 commit 3d0e5ba

2 files changed: +60 -62 lines


samples/cpp/visual_language_chat/visual_language_chat.cpp (+12 -12)
@@ -9,7 +9,7 @@ bool print_subword(std::string&& subword) {
     return !(std::cout << subword << std::flush);
 }
 
-int main(int argc, char* argv[]) try {
+int main(int argc, char* argv[]) {
     if (3 != argc) {
         throw std::runtime_error(std::string{"Usage "} + argv[0] + " <MODEL_DIR> <IMAGE_FILE>");
     }
@@ -31,7 +31,7 @@ int main(int argc, char* argv[]) try {
     }
     pipe.generate(
         prompt,
-        ov::genai::image(std::move(image)),
+        ov::genai::images(std::vector{image, image}),
         ov::genai::streamer(print_subword)
     );
     std::cout << "\n----------\n"
@@ -42,14 +42,14 @@ int main(int argc, char* argv[]) try {
            "question:\n";
     }
     pipe.finish_chat();
-} catch (const std::exception& error) {
-    try {
-        std::cerr << error.what() << '\n';
-    } catch (const std::ios_base::failure&) {}
-    return EXIT_FAILURE;
-} catch (...) {
-    try {
-        std::cerr << "Non-exception object thrown\n";
-    } catch (const std::ios_base::failure&) {}
-    return EXIT_FAILURE;
+// } catch (const std::exception& error) {
+//     try {
+//         std::cerr << error.what() << '\n';
+//     } catch (const std::ios_base::failure&) {}
+//     return EXIT_FAILURE;
+// } catch (...) {
+//     try {
+//         std::cerr << "Non-exception object thrown\n";
+//     } catch (const std::ios_base::failure&) {}
+//     return EXIT_FAILURE;
 }
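For context, the edited sample exercises the new multi-image path by sending the same tensor twice. A minimal sketch of how a caller might pass two distinct images instead, assuming the sample's existing image-loading helper (referred to here as utils::load_image) returns an ov::Tensor and that ov::genai::images accepts a std::vector<ov::Tensor>, as the diff above implies:

// Sketch only, not part of this commit; the file names and the load helper are
// illustrative. Both tensors travel through the new ov::genai::images property.
ov::Tensor first = utils::load_image("cat.jpg");
ov::Tensor second = utils::load_image("dog.jpg");
pipe.generate(
    prompt,
    ov::genai::images(std::vector<ov::Tensor>{first, second}),
    ov::genai::streamer(print_subword)
);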

src/cpp/src/vlm_pipeline.cpp (+48 -50)
@@ -338,10 +338,9 @@ DecodedResults VLMPipeline::generate(
     const StreamerVariant& streamer
 ) {
     std::string images_prompt;
-    EncodedImage embeds;
-    if (!rgbs.empty()) {
-        OPENVINO_ASSERT(1 == rgbs.size(), "TODO: Only a single image allowed");
-        embeds = m_vision_encoder.encode(rgbs.at(0));
+    std::vector<EncodedImage> embeds;
+    for (const ov::Tensor& rgb : rgbs) {
+        EncodedImage encoded_image = m_vision_encoder.encode(rgb);
         if (m_vlm_config.use_image_id) {
             images_prompt = m_vlm_config.im_id_start + std::to_string(image_id) + m_vlm_config.im_id_end;
             ++image_id;
@@ -351,8 +350,8 @@ DecodedResults VLMPipeline::generate(
             unk64 += m_vlm_config.unk;
         }
         images_prompt += m_vlm_config.im_start + unk64 + m_vlm_config.im_end;
-        if (embeds.slices) {
-            ov::Shape slices_shape = embeds.slices.get_shape();
+        if (encoded_image.slices) {
+            ov::Shape slices_shape = encoded_image.slices.get_shape();
             for (size_t row_idx = 0; row_idx < slices_shape.at(0); ++row_idx) {
                 for (size_t col_idx = 0; col_idx < slices_shape.at(1); ++col_idx) {
                     images_prompt += m_vlm_config.slice_start + unk64 + m_vlm_config.slice_end;
@@ -365,6 +364,7 @@ DecodedResults VLMPipeline::generate(
             // Strangely, \n isn't placed between </image><slice>.
             images_prompt += '\n';
         }
+        embeds.push_back(std::move(encoded_image));
     }
     images_prompt += prompt;
     ov::Tensor encoded_input;
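The hunks above turn the single-image branch into a per-image loop: each encoded image appends its own tag block to images_prompt, and the EncodedImage results are collected into the embeds vector for the embedding-merge step later. A rough sketch of one such block, using placeholder tag strings since the real values come from m_vlm_config and are model specific:

// Sketch only: the tag names below are hypothetical stand-ins for
// m_vlm_config.im_id_start, im_start, unk, slice_start, and so on.
std::string unk64;
for (size_t i = 0; i < 64; ++i) {
    unk64 += "<unk>";                 // 64 placeholder tokens per image or slice
}
std::string block = "<image_id>" + std::to_string(image_id) + "</image_id>"
                  + "<image>" + unk64 + "</image>";
// If the encoder produced slices, each slice adds "<slice>" + unk64 + "</slice>",
// with '\n' after every slice row and after the whole image block.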
@@ -402,36 +402,34 @@ DecodedResults VLMPipeline::generate(
         m_vlm_config.hidden_size == inputs_embeds.get_shape().at(2),
         "Unexpected embedding size"
     );
-    if (!rgbs.empty()) {
-        ov::Tensor special_tokens = m_tokenizer.encode(
-            m_vlm_config.im_start
-            + m_vlm_config.im_end
-            + m_vlm_config.slice_start
-            + m_vlm_config.slice_end
-        ).input_ids;
-        OPENVINO_ASSERT(
-            4 == special_tokens.get_shape().at(1),
-            "Every special token must be represented with a single int."
-        );
-        size_t im_start_id = special_tokens.data<int64_t>()[0];
-        size_t im_end_id = special_tokens.data<int64_t>()[1];
-        size_t slice_start_id = special_tokens.data<int64_t>()[2];
-        size_t slice_end_id = special_tokens.data<int64_t>()[3];
-        int64_t* ids = encoded_input.data<int64_t>();
-        const ov::Tensor& resampled_source = resample(*this, embeds.resized_source, {embeds.resized_source_size});
+    ov::Tensor special_tokens = m_tokenizer.encode(
+        m_vlm_config.im_start
+        + m_vlm_config.im_end
+        + m_vlm_config.slice_start
+        + m_vlm_config.slice_end
+    ).input_ids;
+    OPENVINO_ASSERT(
+        4 == special_tokens.get_shape().at(1),
+        "Every special token must be represented with a single int."
+    );
+    size_t im_start_id = special_tokens.data<int64_t>()[0];
+    size_t im_end_id = special_tokens.data<int64_t>()[1];
+    size_t slice_start_id = special_tokens.data<int64_t>()[2];
+    size_t slice_end_id = special_tokens.data<int64_t>()[3];
+    size_t im_start_pos = 0, slice_start_pos = 0;
+    int64_t* begin = encoded_input.data<int64_t>();
+    int64_t* ids = begin;
+    size_t encoded_input_size = encoded_input.get_size();
+    const int64_t* end = ids + encoded_input_size;
+    float* input_embeds_data = input_embeds.data<float>();
+    for (const EncodedImage& encoded_image : embeds) {
+        const ov::Tensor& resampled_source = resample(*this, encoded_image.resized_source, {encoded_image.resized_source_size});
         float* emb = resampled_source.data<float>();
-        bool replacing = false;
-        for (size_t token_idx = 0; token_idx < inputs_embeds.get_shape().at(1); ++token_idx) {
-            if (im_start_id == ids[token_idx]) {
-                replacing = true;
-            }
-            if (replacing) {
-                std::copy_n(emb, resampled_source.get_size(), inputs_embeds.data<float>() + token_idx * m_vlm_config.hidden_size);
-                token_idx += resampled_source.get_shape().at(1);
-                replacing = false;
-                break;
-            }
+        ids = std::find(ids, end, im_start_id);
+        if (end == ids) {
+            break;
         }
+        ids = std::copy_n(emb, resampled_source.get_size(), input_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size);
         if (embeds.slices) {
             size_t token_idx = 0;
             const ov::Shape& slices_shape = embeds.slices.get_shape();
@@ -442,21 +440,11 @@ DecodedResults VLMPipeline::generate(
                 size_t d3 = slices_shape.at(3);
                 ov::Tensor encoded_view{ov::element::f32, {1, d2, d3}, embeds.slices.data<float>() + (i * slices_shape.at(1) + ja) * d2 * d3};
                 const ov::Tensor& vision_embed_tensor_i_j = resample(*this, encoded_view, {sliced_sizes.at(i * slices_shape.at(1) + ja)});
-                for (; token_idx < inputs_embeds.get_shape().at(1); ++token_idx) {
-                    if (slice_start_id == ids[token_idx]) {
-                        replacing = true;
-                    }
-                    if (slice_end_id == ids[token_idx]) {
-                        replacing = false;
-                        break;
-                    }
-                    if (replacing) {
-                        std::copy_n(vision_embed_tensor_i_j.data<float>(), vision_embed_tensor_i_j.get_size(), inputs_embeds.data<float>() + token_idx * m_vlm_config.hidden_size);
-                        token_idx += vision_embed_tensor_i_j.get_shape().at(1);
-                        replacing = false;
-                        break;
-                    }
+                ids = std::find(ids, end, slice_start_id);
+                if (end == ids) {
+                    break;
                 }
+                ids = std::copy_n(vision_embed_tensor_i_j.data<float>(), vision_embed_tensor_i_j.get_size(), input_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size);
             }
         }
     }
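The rewritten merge loop drops the old replacing flag: std::find advances a cursor over the encoded token ids to the next image (or slice) marker, and the resampled vision embeddings are copied into the embedding buffer at that token offset, one image after another. A minimal, self-contained sketch of that find-and-scatter idea, with the token-id cursor and the destination buffer kept as separate variables (the helper and all names are illustrative, not part of the pipeline):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

// Sketch only: scatter one flat embedding block per image into a
// [n_tokens x hidden] buffer at the position of a marker token.
void scatter_embeddings(
    const std::vector<int64_t>& token_ids,          // encoded prompt ids
    const std::vector<std::vector<float>>& images,  // one embedding block per image
    std::vector<float>& inputs_embeds,              // n_tokens * hidden floats
    size_t hidden,
    int64_t im_start_id
) {
    auto ids = token_ids.begin();
    const auto end = token_ids.end();
    for (const std::vector<float>& emb : images) {
        ids = std::find(ids, end, im_start_id);     // next image marker
        if (end == ids) {
            break;                                  // fewer markers than images
        }
        size_t offset = std::distance(token_ids.begin(), ids) * hidden;
        std::copy_n(emb.begin(), emb.size(), inputs_embeds.begin() + offset);
        ids += emb.size() / hidden;                 // skip the span just replaced
    }
}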
@@ -552,13 +540,23 @@ DecodedResults VLMPipeline::generate(
     const ov::AnyMap& config_map
 ) {
     auto image = config_map.find(ov::genai::image.name());
+    auto images = config_map.find(ov::genai::images.name());
+    OPENVINO_ASSERT(
+        config_map.end() == image || config_map.end() == images,
+        "Only one property can be set: image of images."
+    );
+    std::vector<ov::Tensor> rgbs;
+    if (config_map.end() != image) {
+        rgbs = {image->second.as<ov::Tensor>()};
+    } if (config_map.end() != images) {
+        rgbs = images->second.as<std::vector<ov::Tensor>>();
+    }
     ov::genai::OptionalGenerationConfig config_arg = utils::get_config_from_map(config_map);
     GenerationConfig config = (config_arg.has_value()) ? *config_arg : get_generation_config();
     config.update_generation_config(config_map);
     return generate(
         prompt,
-        config_map.end() == image ? std::vector<ov::Tensor>{}
-        : std::vector{image->second.as<ov::Tensor>()},
+        rgbs,
         config,
         utils::get_streamer_from_map(config_map)
     );