Commit 2dbf413

add tests
1 parent 9043078 commit 2dbf413

File tree

  optimum/exporters/openvino/model_patcher.py
  optimum/exporters/openvino/stateful.py
  optimum/intel/openvino/modeling_visual_language.py
  tests/openvino/test_modeling.py
  tests/openvino/utils_tests.py

5 files changed: +25 −13 lines changed

optimum/exporters/openvino/model_patcher.py

+10 −1

@@ -3988,7 +3988,16 @@ def __init__(
 
         @functools.wraps(model.__orig_forward)
         def patched_forward(*args, **kwargs):
-            return model.model.forward(*args, **kwargs)
+            fwd_args = inspect.signature(model.__orig_forward).parameters
+            internal_fwd_args = inspect.signature(model.model.forward).parameters
+            inputs = {}
+            for arg, fwd_arg_name in zip(args, fwd_args):
+                if fwd_arg_name in internal_fwd_args:
+                    inputs[fwd_arg_name] = arg
+            for key, value in kwargs.items():
+                if key in internal_fwd_args:
+                    inputs[key] = value
+            return model.model.forward(**inputs)
 
         model.forward = patched_forward
         self._internal_patcher = internal_patcher
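
For context, the new patched_forward filters both positional and keyword arguments against the inner model's signature, so wrapper-only arguments are dropped before the call is forwarded. A minimal standalone sketch of the same technique (function and variable names here are illustrative, not taken from the patch):

import inspect

def make_filtered_forward(outer_fn, inner_fn):
    # outer_fn supplies the public signature; inner_fn only receives
    # the arguments it actually declares, everything else is dropped.
    outer_params = inspect.signature(outer_fn).parameters
    inner_params = inspect.signature(inner_fn).parameters

    def forward(*args, **kwargs):
        inputs = {}
        # Map positional arguments onto the outer signature's parameter names.
        for arg, name in zip(args, outer_params):
            if name in inner_params:
                inputs[name] = arg
        # Keep only the keyword arguments the inner callable accepts.
        for key, value in kwargs.items():
            if key in inner_params:
                inputs[key] = value
        return inner_fn(**inputs)

    return forward

def public(a, b, extra=None):
    pass

def internal(a, b):
    return a + b

# "extra" is silently dropped because internal() does not declare it.
assert make_filtered_forward(public, internal)(1, 2, extra="ignored") == 3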

optimum/exporters/openvino/stateful.py

−1

@@ -290,7 +290,6 @@ def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
         openvino model
     """
 
-    log.warn(ov_model)
     key_value_input_names = [
         key_name for key in ov_model.inputs for key_name in key.get_names() if "key_values" in key_name
     ]

optimum/intel/openvino/modeling_visual_language.py

+9 −6

@@ -349,13 +349,13 @@ def __init__(
         language_model: ov.Model,
         text_embeddings: ov.Model,
         vision_embeddings: ov.Model,
-        lm_head: Optional[ov.Model] = None,
         config: PretrainedConfig = None,
         device: str = "CPU",
         dynamic_shapes: bool = True,
         ov_config: Optional[Dict[str, str]] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
+        lm_head: Optional[ov.Model] = None,
         **kwargs,
     ):
         self.config = config
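
Moving the new lm_head parameter behind the pre-existing ones (rather than leaving it inserted fourth) keeps positional callers working. A hedged illustration of the breakage the reorder avoids, using simplified stand-in signatures rather than the real class:

# Simplified signatures, not the real __init__.
def init_v1(lm, text_emb, vision_emb, config=None):
    return config

# Inserting a new optional parameter in the middle shifts later positionals:
def init_broken(lm, text_emb, vision_emb, lm_head=None, config=None):
    return config

# Appending it at the end leaves existing call sites untouched:
def init_fixed(lm, text_emb, vision_emb, config=None, lm_head=None):
    return config

cfg = object()
assert init_v1(1, 2, 3, cfg) is cfg
assert init_broken(1, 2, 3, cfg) is not cfg  # cfg silently lands in lm_head
assert init_fixed(1, 2, 3, cfg) is cfg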
@@ -717,6 +717,9 @@ def components(self):
     def _submodel_names(self):
         model_names = ["lm_model", "text_embeddings_model", "vision_embeddings_model"]
         for part in self.additional_parts:
+            if part == "lm_head" and getattr(self, part + "_model", None) is not None:
+                model_names.append(part + "_model")
+                continue
             if getattr(self, part, None) is not None:
                 model_names.append(part + "_model")
         return model_names
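
The extra branch is needed because, unlike the other additional parts, the optional LM head appears to be stored only under the suffixed attribute name, so the generic getattr(self, part) probe would never find it. A hypothetical minimal reproduction of the lookup logic:

class Submodels:
    # Hypothetical part names; the real list lives on the model class.
    additional_parts = ["resampler", "lm_head"]

    def __init__(self):
        self.resampler = object()      # generic parts: bare attribute exists
        self.lm_head_model = object()  # lm_head: only the suffixed attribute exists

    def submodel_names(self):
        names = ["lm_model", "text_embeddings_model", "vision_embeddings_model"]
        for part in self.additional_parts:
            if part == "lm_head" and getattr(self, part + "_model", None) is not None:
                names.append(part + "_model")
                continue
            if getattr(self, part, None) is not None:
                names.append(part + "_model")
        return names

# Without the special case, "lm_head_model" would never be reported.
assert "lm_head_model" in Submodels().submodel_names()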
@@ -2472,6 +2475,7 @@ def generate_image(
         image_token_num_per_image: int = 576,
         img_size: int = 384,
         patch_size: int = 16,
+        generator=None
     ):
         from PIL import Image
 
@@ -2520,7 +2524,7 @@ def generate_image(
             logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
             probs = torch.softmax(logits / temperature, dim=-1)
 
-            next_token = torch.multinomial(probs, num_samples=1)
+            next_token = torch.multinomial(probs, num_samples=1) if generator is None else torch.multinomial(probs, num_samples=1, generator=generator)
             generated_tokens[:, i] = next_token.squeeze(dim=-1)
 
             next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
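
The new generator argument threads an optional torch.Generator into the multinomial draw, which makes image-token sampling reproducible. A small self-contained sketch of the mechanism:

import torch

probs = torch.softmax(torch.randn(2, 16), dim=-1)

gen = torch.Generator()
gen.manual_seed(42)
first = torch.multinomial(probs, num_samples=1, generator=gen)

gen.manual_seed(42)  # re-seeding reproduces the exact same draw
second = torch.multinomial(probs, num_samples=1, generator=gen)
assert torch.equal(first, second)

# With generator=None the call falls back to torch's global RNG state,
# matching the behavior before this change.
unseeded = torch.multinomial(probs, num_samples=1)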
@@ -2563,11 +2567,10 @@ def preprocess_inputs(
                 },
                 {"role": "<|Assistant|>", "content": ""},
             ]
-            prompt = None
+            prepare_inputs = processor(conversations=conversation, images=[image], force_batchify=True)
         else:
-            conversation = None
-            prompt = text
-        prepare_inputs = processor(prompt=prompt, conversations=conversation, images=[image], force_batchify=True)
+            tokenizer = tokenizer if tokenizer is not None else processor.tokenizer
+            prepare_inputs = tokenizer(text, return_tensors="pt")
         required_keys = ["input_ids", "pixel_values", "images_seq_mask", "images_emb_mask"]
         inputs = {}
         for key in required_keys:
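
The rewritten text-only branch skips the multimodal processor entirely and tokenizes directly, preferring an explicitly passed tokenizer over the processor's bundled one. A hedged sketch of that fallback pattern (assuming the processor exposes a .tokenizer attribute, as the diff does):

def encode_text(text, tokenizer=None, processor=None):
    # Prefer the caller-supplied tokenizer; otherwise fall back to the
    # tokenizer bundled with the processor.
    tokenizer = tokenizer if tokenizer is not None else processor.tokenizer
    return tokenizer(text, return_tensors="pt")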

tests/openvino/test_modeling.py

+5 −5

@@ -2114,10 +2114,10 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
     if is_transformers_version(">=", "4.40.0"):
         SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]
     if is_transformers_version(">=", "4.45.0"):
-        SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2", "phi3_v", "qwen2_vl", "maira2"]
-    TASK = "image-text-to-text"
-    REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava", "phi3_v", "maira2"]
 
+        SUPPORTED_ARCHITECTURES += ["janus", "minicpmv", "internvl2", "phi3_v", "qwen2_vl", "maira2"]
+    TASK = "image-text-to-text"
+    REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava", "phi3_v", "maira2", "janus"]
     IMAGE = Image.open(
         requests.get(
             TEST_IMAGE_URL,
@@ -2216,8 +2216,8 @@ def test_compare_to_transformers(self, model_arch):
         with torch.no_grad():
             transformers_outputs = transformers_model.generate(**transformers_inputs, generation_config=gen_config)
 
-        # original minicpmv, internvl always skip input tokens in generation results, while transformers based approach provide them
-        if model_arch in ["minicpmv", "internvl2"]:
+        # original minicpmv, internvl and janus always skip input tokens in generation results, while the transformers-based approach provides them
+        if model_arch in ["minicpmv", "internvl2", "janus"]:
             ov_outputs = ov_outputs[:, inputs["input_ids"].shape[1] :]
         self.assertTrue(
             torch.equal(ov_outputs, transformers_outputs),
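
Trimming the echoed prompt so both outputs compare equal is a single slice along the sequence dimension. A minimal illustration with dummy tensors:

import torch

input_ids = torch.ones(1, 5, dtype=torch.long)   # a 5-token prompt
ov_outputs = torch.arange(12).reshape(1, 12)     # prompt + 7 generated tokens
new_tokens = ov_outputs[:, input_ids.shape[1]:]  # drop the echoed prompt
assert new_tokens.shape == (1, 7)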

tests/openvino/utils_tests.py

+1

@@ -170,6 +170,7 @@
     "st-bert": "sentence-transformers/all-MiniLM-L6-v2",
     "st-mpnet": "sentence-transformers/all-mpnet-base-v2",
     "sana": "katuni4ka/tiny-random-sana",
+    "janus": "katuni4ka/tiny-random-janus"
 }