|
24 | 24 | PretrainedConfig,
|
25 | 25 | PreTrainedTokenizer,
|
26 | 26 | )
|
27 |
| -from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutputWithPast |
| 27 | +from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling |
28 | 28 | from transformers.utils import ModelOutput
|
29 | 29 |
|
30 | 30 | from ...exporters.openvino import main_export
|
@@ -125,7 +125,6 @@ def _compile_lm_head(self):
|
125 | 125 | self.lm_head_model, self._device, self.ov_config, self.model_save_dir
|
126 | 126 | )
|
127 | 127 |
|
128 |
| - |
129 | 128 | def clear_requests(self):
|
130 | 129 | if self._compile_only:
|
131 | 130 | raise ValueError(
|
@@ -235,11 +234,13 @@ def forward(
|
235 | 234 | if self.lm_head_request is not None:
|
236 | 235 | last_hidden_state = self.request.get_tensor("last_hidden_state").data
|
237 | 236 | if include_head:
|
238 |
| - logits = self.lm_head_request(logits)[0] |
| 237 | + logits = self.lm_head_request(last_hidden_state)[0] |
239 | 238 | else:
|
240 |
| - return BaseModelOutputWithPast(torch.from_numpy(last_hidden_state).to(self.device), past_key_values=past_key_values) |
| 239 | + return BaseModelOutputWithPast( |
| 240 | + torch.from_numpy(last_hidden_state).to(self.device), past_key_values=past_key_values |
| 241 | + ) |
241 | 242 | else:
|
242 |
| - logits = self.request.get_tensor("logits" if self.lm_head_request is None else "last_hidden_state").data |
| 243 | + logits = self.request.get_tensor("logits").data |
243 | 244 | logits = torch.from_numpy(logits).to(self.device)
|
244 | 245 |
|
245 | 246 | return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
@@ -417,7 +418,7 @@ def __init__(
|
417 | 418 | )
|
418 | 419 | self.vision_embeddings = OVVisionEmbedding(self.vision_embeddings_model, self)
|
419 | 420 | for part in self.additional_parts:
|
420 |
| - if model_part == "lm_head": |
| 421 | + if part == "lm_head": |
421 | 422 | continue
|
422 | 423 | model_part = getattr(self, f"{part}_model", None)
|
423 | 424 | if model_part is not None:
|
@@ -909,7 +910,7 @@ def prepare_inputs_for_generation(
|
909 | 910 | "video_grid_thw": kwargs.get("video_grid_thw"),
|
910 | 911 | "token_type_ids": kwargs.get("token_type_ids"),
|
911 | 912 | "images_seq_mask": kwargs.get("images_seq_mask"),
|
912 |
| - "images_emb_mask": kwargs.get("images_emb_mask") |
| 913 | + "images_emb_mask": kwargs.get("images_emb_mask"), |
913 | 914 | }
|
914 | 915 | )
|
915 | 916 | return model_inputs
|
@@ -3309,7 +3310,6 @@ def preprocess_inputs(
|
3309 | 3310 | return processed_inputs
|
3310 | 3311 |
|
3311 | 3312 |
|
3312 |
| - |
3313 | 3313 | class _OVJanusForCausalLM(OVModelForVisualCausalLM):
|
3314 | 3314 | additional_parts = ["vision_gen_embeddings", "vision_gen_head", "vision_gen_decoder", "lm_head"]
|
3315 | 3315 |
|
@@ -3364,6 +3364,7 @@ def generate_image(
|
3364 | 3364 | img_size: int = 384,
|
3365 | 3365 | patch_size: int = 16,
|
3366 | 3366 | generator=None,
|
| 3367 | + show_progress=True, |
3367 | 3368 | ):
|
3368 | 3369 | from PIL import Image
|
3369 | 3370 |
|
@@ -3483,5 +3484,5 @@ def preprocess_inputs(
|
3483 | 3484 | "qwen2_5_vl": _OVQwen2_5_VLForCausalLM,
|
3484 | 3485 | "got_ocr2": _OVGotOCR2ForCausalLM,
|
3485 | 3486 | "gemma3": _OVGemma3ForCausalLM,
|
3486 |
| - "multi_modality": _OVJanusForCausalLM |
| 3487 | + "multi_modality": _OVJanusForCausalLM, |
3487 | 3488 | }
|
0 commit comments