Commit 8ab6a84 1 parent c08b95d commit 8ab6a84 Copy full SHA for 8ab6a84
File tree 1 file changed +2
-2
lines changed
1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -1243,7 +1243,7 @@ def merge_vision_text_embeddings(
1243
1243
1244
1244
# Whether to turn off right padding
1245
1245
# 1. Create a mask to know where special image tokens are
1246
- special_image_token_mask = torch . tensor ( input_ids == image_token_index )
1246
+ special_image_token_mask = input_ids == image_token_index
1247
1247
# special_image_token_mask: [bsz, seqlen]
1248
1248
num_special_image_tokens = torch .sum (special_image_token_mask , dim = - 1 )
1249
1249
# num_special_image_tokens: [bsz]
@@ -1336,7 +1336,7 @@ def merge_vision_text_embeddings(
1336
1336
final_attention_mask |= image_to_overwrite
1337
1337
position_ids = (final_attention_mask .cumsum (- 1 ) - 1 ).masked_fill_ ((final_attention_mask == 0 ), 1 )
1338
1338
else :
1339
- special_image_mask = torch . tensor (( input_ids == image_token_index ) ).unsqueeze (- 1 ).expand_as (inputs_embeds )
1339
+ special_image_mask = ( input_ids == image_token_index ).unsqueeze (- 1 ).expand_as (inputs_embeds )
1340
1340
image_features = image_features .to (inputs_embeds .dtype )
1341
1341
final_embedding = inputs_embeds .masked_scatter (special_image_mask , image_features )
1342
1342
final_attention_mask = attention_mask
You can’t perform that action at this time.
0 commit comments