@@ -747,26 +747,22 @@ def _prepare_visual_causal_lm_dataset(self, config: OVWeightQuantizationConfig,
747
747
pbar = tqdm (desc = "Collecting calibration dataset" , total = num_samples )
748
748
for item in dataset :
749
749
image_url = item [dataset_metadata ["inputs" ]["image_url" ]]
750
- instruction = item [dataset_metadata ["inputs" ]["instruction" ]]
751
750
image = Image .open (requests .get (image_url , stream = True ).raw )
752
751
752
+ instruction = item [dataset_metadata ["inputs" ]["instruction" ]]
753
753
chat_template = [{"role" : "user" , "content" : [{"type" : "text" , "text" : instruction }, {"type" : "image" }]}]
754
754
prompt = processor .apply_chat_template (chat_template , add_generation_prompt = True )
755
-
756
755
inputs = processor (images = image , text = prompt , return_tensors = "pt" )
757
- if inputs .input_ids .size (1 ) > max_tokens :
758
- continue
759
756
input_ids = inputs .input_ids
760
- attention_mask = inputs .attention_mask
761
- position_ids = torch .arange (attention_mask .size (1 )).unsqueeze (0 ).to (attention_mask .device )
762
- pixel_values = inputs .pixel_values
763
- image_sizes = inputs .image_sizes
757
+ if input_ids .size (1 ) > max_tokens :
758
+ continue
764
759
760
+ position_ids = torch .arange (inputs .input_ids .size (1 )).unsqueeze (0 ).to (inputs .input_ids .device )
765
761
inputs_embeds , attention_mask , position_ids = self .model .get_multimodal_embeddings (
766
762
input_ids ,
767
- pixel_values ,
768
- image_sizes = image_sizes ,
769
- attention_mask = attention_mask ,
763
+ inputs . pixel_values ,
764
+ image_sizes = inputs . image_sizes ,
765
+ attention_mask = inputs . attention_mask ,
770
766
position_ids = position_ids ,
771
767
)
772
768
@@ -776,6 +772,7 @@ def _prepare_visual_causal_lm_dataset(self, config: OVWeightQuantizationConfig,
776
772
position_ids = position_ids ,
777
773
inputs_embeds = inputs_embeds ,
778
774
)
775
+
779
776
pbar .update (1 )
780
777
calibration_dataset .append (language_model_inputs )
781
778
if len (calibration_dataset ) == num_samples :
0 commit comments