1
1
// Copyright (C) 2023-2024 Intel Corporation
2
2
// SPDX-License-Identifier: Apache-2.0
3
3
4
- #include " visual_language/image_embedder .hpp"
4
+ #include " visual_language/inputs_embedder .hpp"
5
5
6
6
#include " visual_language/clip.hpp"
7
7
#include " visual_language/vision_encoder.hpp"
@@ -905,56 +905,60 @@ class InputsEmbedderInternVLChat : public InputsEmbedder::IInputsEmbedder {
905
905
IInputsEmbedder (vlm_config, model_dir, device, device_config) { }
906
906
907
907
virtual ov::Tensor get_inputs_embeds (const std::string& prompt, const std::vector<ov::Tensor>& images) override {
908
- if (images.empty ()) {
909
- ov::Tensor input_ids = get_encoded_input_ids (prompt);
910
- return m_embedding.infer (input_ids);
911
- } else {
912
- OPENVINO_ASSERT (1 == images.size (), " Only a single image allowed" );
913
- EncodedImage encoded_image = m_vision_encoder.encode (images.at (0 ));
914
- ov::Tensor image_embeds = encoded_image.resized_source ;
915
-
916
- std::string image_start_token = m_vlm_config.image_start_token ;
917
- std::string image_context_token = m_vlm_config.image_context_token ;
918
- std::string image_end_token = m_vlm_config.image_end_token ;
908
+ std::string image_start_token = m_vlm_config.image_start_token ;
909
+ std::string image_context_token = m_vlm_config.image_context_token ;
910
+ std::string image_end_token = m_vlm_config.image_end_token ;
911
+
912
+ std::vector<ov::Tensor> single_images = to_single_image_tensors (images);
913
+
914
+ std::string formatted_prompt;
915
+ std::vector<ov::Tensor> image_embeds;
916
+ image_embeds.reserve (single_images.size ());
917
+
918
+ for (const auto & image : single_images) {
919
+ EncodedImage encoded_image = m_vision_encoder.encode (image);
920
+ ov::Tensor single_image_embeds = encoded_image.resized_source ;
919
921
920
- const size_t num_patches = image_embeds .get_shape ().at (0 );
921
- const size_t num_image_tokens = image_embeds .get_shape ().at (1 );
922
+ const size_t num_patches = single_image_embeds .get_shape ().at (0 );
923
+ const size_t num_image_tokens = single_image_embeds .get_shape ().at (1 );
922
924
923
- std::string concated_image_tokens;
924
- concated_image_tokens += image_start_token;
925
+ formatted_prompt += image_start_token;
925
926
for (int i = 0 ; i < num_patches * num_image_tokens; ++i) {
926
- concated_image_tokens += image_context_token;
927
+ formatted_prompt += image_context_token;
927
928
}
928
- concated_image_tokens += image_end_token;
929
+ formatted_prompt += image_end_token + " \n " ;
929
930
930
- std::string formatted_prompt = concated_image_tokens + " \n " + prompt;
931
-
932
- ov::Tensor input_ids = get_encoded_input_ids (formatted_prompt);
933
- ov::Tensor text_embeds = m_embedding.infer (input_ids);
931
+ image_embeds.push_back (std::move (single_image_embeds));
932
+ }
933
+ formatted_prompt += prompt;
934
934
935
- ov::Tensor encoded_image_context_token = m_tokenizer. encode (image_context_token, ov::genai::add_special_tokens ( false )). input_ids ;
936
- int64_t image_context_token_id = encoded_image_context_token. data < int64_t >()[encoded_image_context_token. get_size () - 1 ] ;
935
+ ov::Tensor input_ids = get_encoded_input_ids (formatted_prompt) ;
936
+ ov::Tensor text_embeds = m_embedding. infer (input_ids) ;
937
937
938
- return merge_text_and_image_embeddings_internvl (input_ids, text_embeds, image_embeds, image_context_token_id);
938
+ if (images.empty ()) {
939
+ return text_embeds;
939
940
}
941
+
942
+ ov::Tensor encoded_image_context_token = m_tokenizer.encode (image_context_token, ov::genai::add_special_tokens (false )).input_ids ;
943
+ int64_t image_context_token_id = encoded_image_context_token.data <int64_t >()[encoded_image_context_token.get_size () - 1 ];
944
+
945
+ return merge_text_and_image_embeddings_internvl (input_ids, text_embeds, image_embeds, image_context_token_id);
940
946
}
941
947
942
948
protected:
943
949
ov::Tensor merge_text_and_image_embeddings_internvl (
944
950
const ov::Tensor& input_ids,
945
951
const ov::Tensor& text_embeds,
946
- const ov::Tensor& image_embeds,
952
+ const std::vector< ov::Tensor> & image_embeds,
947
953
int64_t image_context_token_id
948
954
) {
949
955
auto text_embeds_shape = text_embeds.get_shape ();
950
- auto image_embeds_shape = image_embeds.get_shape ();
951
956
size_t batch_size = text_embeds_shape.at (0 );
952
957
size_t seq_len = text_embeds_shape.at (1 );
953
958
size_t embed_dim = text_embeds_shape.at (2 );
954
959
955
960
ov::Tensor merged_embeds (text_embeds.get_element_type (), text_embeds_shape);
956
961
957
- const float * image_embeds_data = image_embeds.data <float >();
958
962
const float * text_embeds_data = text_embeds.data <float >();
959
963
const int64_t * input_ids_data = input_ids.data <int64_t >();
960
964
float * merged_embeds_data = merged_embeds.data <float >();
@@ -972,15 +976,27 @@ class InputsEmbedderInternVLChat : public InputsEmbedder::IInputsEmbedder {
972
976
973
977
OPENVINO_ASSERT (image_context_tokens_count > 0 , " input_ids does not contain image context token ids" );
974
978
975
- size_t vision_idx = 0 ;
979
+ size_t image_idx = 0 ;
980
+ size_t image_context_token_idx = 0 ;
976
981
for (size_t i = 0 ; i < batch_size; ++i) {
977
982
for (size_t j = 0 ; j < seq_len; ++j) {
978
983
size_t flat_idx = i * seq_len + j;
979
984
size_t offset = flat_idx * embed_dim;
980
985
981
986
if (image_context_tokens_mask[flat_idx]) {
982
- std::copy_n (image_embeds_data + vision_idx * embed_dim, embed_dim, merged_embeds_data + offset);
983
- ++vision_idx;
987
+ const ov::Tensor& single_image_embeds = image_embeds[image_idx];
988
+ const size_t num_all_image_tokens = single_image_embeds.get_shape ().at (0 ) * single_image_embeds.get_shape ().at (1 ); // num_patches * num_image_tokens
989
+ const float * image_embeds_data = single_image_embeds.data <float >();
990
+ std::copy_n (image_embeds_data + image_context_token_idx * embed_dim,
991
+ embed_dim,
992
+ merged_embeds_data + offset);
993
+
994
+ ++image_context_token_idx;
995
+
996
+ if (image_context_token_idx == num_all_image_tokens) {
997
+ ++image_idx;
998
+ image_context_token_idx = 0 ;
999
+ }
984
1000
} else {
985
1001
std::copy_n (text_embeds_data + offset, embed_dim, merged_embeds_data + offset);
986
1002
}
0 commit comments