
Commit c9c8beb

test for video
1 parent f2e3135 commit c9c8beb

2 files changed (+35 -11 lines)

optimum/intel/openvino/modeling_visual_language.py (+6 -2)
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import openvino as ov
@@ -55,7 +55,11 @@

 if TYPE_CHECKING:
     from PIL.Image import Image
-    from transformers.image_utils import VideoInput
+
+    if is_transformers_version(">=", "4.42.0"):
+        from transformers.image_utils import VideoInput
+    else:
+        VideoInput = List[Image]


 logger = logging.getLogger(__name__)
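The hunk above makes the VideoInput annotation resolvable on older transformers releases, where transformers.image_utils.VideoInput does not exist, by aliasing it to a list of PIL images. Below is a minimal, self-contained sketch of the same version-gated import outside the TYPE_CHECKING block, using packaging.version in place of optimum's is_transformers_version helper; the frame_count function is only an illustration and is not part of the commit.

from typing import TYPE_CHECKING, List

import transformers
from packaging import version

if TYPE_CHECKING:
    from PIL.Image import Image

# VideoInput only exists in newer transformers releases, so gate the import on the
# installed version and fall back to a plain list of PIL frames on older ones.
if version.parse(transformers.__version__) >= version.parse("4.42.0"):
    from transformers.image_utils import VideoInput
else:
    VideoInput = List["Image"]


def frame_count(video: "VideoInput") -> int:
    # the annotation resolves under either branch
    return len(video)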

tests/openvino/test_modeling.py (+29 -9)
@@ -31,7 +31,7 @@
 import torch
 from datasets import load_dataset
 from evaluate import evaluator
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, hf_hub_download
 from parameterized import parameterized
 from PIL import Image
 from sentence_transformers import SentenceTransformer
@@ -2126,21 +2126,25 @@ def test_compare_with_and_without_past_key_values(self):

 class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = ["llava"]
+    SUPPORT_VIDEO = []

-    # if is_transformers_version(">=", "4.40.0"):
-    #     SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]
+    if is_transformers_version(">=", "4.40.0"):
+        SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]

     if is_transformers_version(">=", "4.42.0"):
         SUPPORTED_ARCHITECTURES += ["llava_next_video"]
+        SUPPORT_VIDEO.append("llava_next_video")

-    # if is_transformers_version(">=", "4.45.0"):
-    #     SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2", "phi3_v", "qwen2_vl"]
+    if is_transformers_version(">=", "4.45.0"):
+        SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2", "phi3_v", "qwen2_vl"]
+        SUPPORT_VIDEO.append("qwen2_vl")

-    # if is_transformers_version(">=", "4.46.0"):
-    #     SUPPORTED_ARCHITECTURES += ["maira2"]
+    if is_transformers_version(">=", "4.46.0"):
+        SUPPORTED_ARCHITECTURES += ["maira2"]

-    # if is_transformers_version(">=", "4.49.0"):
-    #     SUPPORTED_ARCHITECTURES += ["qwen2_5_vl"]
+    if is_transformers_version(">=", "4.49.0"):
+        SUPPORTED_ARCHITECTURES += ["qwen2_5_vl"]
+        SUPPORT_VIDEO.append("qwen2_5_vl")
     TASK = "image-text-to-text"
     REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava", "phi3_v", "maira2"]

@@ -2350,6 +2354,22 @@ def test_generate_utils(self, model_arch):
         outputs = outputs[:, inputs["input_ids"].shape[1] :]
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
         self.assertIsInstance(outputs[0], str)
+
+        # video loader helper only available for transformers >= 4.49
+        if model_arch in self.SUPPORT_VIDEO and is_transformers_version(">=", "4.49"):
+            from transformers.image_utils import load_video
+
+            video_path = hf_hub_download(
+                repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
+            )
+            input_video = load_video(video_path, num_frames=4)
+            question = "Why is this video funny?"
+            inputs = model.preprocess_inputs(**preprocessors, text=question, video=input_video)
+            outputs = model.generate(**inputs, max_new_tokens=10)
+            # filter out the original prompt because it may contain out-of-tokenizer tokens, e.g. in nanollava the text separator is -200
+            outputs = outputs[:, inputs["input_ids"].shape[1] :]
+            outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            self.assertIsInstance(outputs[0], str)
         del model

         gc.collect()
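For context, the new test block exercises the video path end to end: it downloads a short clip from the hub, samples four frames with load_video, and feeds them through preprocess_inputs and generate. Below is a hedged sketch of that same flow outside the test suite; the checkpoint id and the keyword names passed to preprocess_inputs are assumptions (the test forwards its preprocessors as **preprocessors), so treat it as an outline rather than a verified recipe.

from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoTokenizer
from transformers.image_utils import load_video  # helper available in transformers >= 4.49

from optimum.intel import OVModelForVisualCausalLM

# assumed checkpoint for illustration; any llava_next_video model supported by optimum-intel
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# same sample video as the test uses
video_path = hf_hub_download(
    repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
input_video = load_video(video_path, num_frames=4)

# keyword names are assumed here; the test passes its preprocessors as **preprocessors
inputs = model.preprocess_inputs(
    processor=processor, tokenizer=tokenizer, text="Why is this video funny?", video=input_video
)
outputs = model.generate(**inputs, max_new_tokens=10)
# drop the prompt tokens before decoding, as the test does
outputs = outputs[:, inputs["input_ids"].shape[1] :]
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])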
