diff --git a/docs/source/openvino/models.mdx b/docs/source/openvino/models.mdx
index e1c8c7864e..7b696e617f 100644
--- a/docs/source/openvino/models.mdx
+++ b/docs/source/openvino/models.mdx
@@ -43,6 +43,9 @@ Here is the list of the supported architectures :
 - Deberta-v2
 - DeciLM
 - Deit
+- Deepseek
+- Deepseek_v2
+- Deepseek_v3
 - DistilBert
 - Electra
 - Encoder Decoder
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 6a70c3b5ad..12561408db 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -47,7 +47,6 @@
     _transformers_version,
     compare_versions,
     is_openvino_tokenizers_version,
-    is_openvino_version,
     is_tokenizers_version,
     is_transformers_version,
 )
@@ -67,6 +66,7 @@
     OV_XML_FILE_NAME,
     _get_input_info,
     _get_open_clip_submodels_fn_and_export_configs,
+    allow_skip_tracing_check,
     clear_class_registry,
     remove_none_from_dummy_inputs,
     save_config,
@@ -437,7 +437,9 @@ def ts_patched_forward(*args, **kwargs):
             patcher.patched_forward = ts_patched_forward
 
         ts_decoder_kwargs = {}
-        if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
+        model_config = getattr(model, "config", {})
+        model_type = getattr(model_config, "model_type", "").replace("_", "-")
+        if allow_skip_tracing_check(library_name, model_type):
             ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
 
         with patcher:
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 6807644b9e..81263c15e0 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -75,6 +75,7 @@
     CodeGenModelPatcher,
     DBRXModelPatcher,
     DeciLMModelPatcher,
+    DeepseekPatcher,
     FalconModelPatcher,
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
@@ -2782,3 +2783,17 @@ class MT5OpenVINOConfig(T5OpenVINOConfig):
 )
 class LongT5OpenVINOConfig(T5OpenVINOConfig):
     pass
+
+
+@register_in_tasks_manager(
+    "deepseek-v3", *["text-generation", "text-generation-with-past"], library_name="transformers"
+)
+@register_in_tasks_manager(
+    "deepseek-v2", *["text-generation", "text-generation-with-past"], library_name="transformers"
+)
+@register_in_tasks_manager("deepseek", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class DeepseekOpenVINOConfig(MiniCPM3OpenVINOConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return DeepseekPatcher(self, model, model_kwargs=model_kwargs)
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index b524f91485..87154f7ff9 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -3575,6 +3575,301 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.self_attn.forward = block.self_attn._orig_forward
 
 
+class DeepseekPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        self_attn = {
+            "deepseek_v3": deepseek_v3_attn_forward,
+            "deepseek_v2": deepseek_v2_attn_forward,
+            "deepseek": minicpm3_attn_forward,
+        }
+
+        self_attn_fwd = self_attn.get(self._model.config.model_type)
+        for block in self._model.model.layers:
+            if self_attn_fwd is not None:
+                block.self_attn._orig_forward = block.self_attn.forward
+                block.self_attn.forward = types.MethodType(self_attn_fwd, block.self_attn)
+            if hasattr(block.mlp, "moe_infer"):
+                block.mlp._orig_moe_infer = block.mlp.moe_infer
+                block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.model.layers:
+            block.self_attn.forward = block.self_attn._orig_forward
+            if hasattr(block.mlp, "_orig_moe_infer"):
+                block.mlp.moe_infer = block.mlp._orig_moe_infer
+
+
+def deepseek_v3_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value=None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # modified from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L751
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        orig_dtype = k.dtype
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        q_fp32 = q.to(dtype=torch.float32, device=q.device)
+        k_fp32 = k.to(dtype=torch.float32, device=k.device)
+        q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
+        k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
+        return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+
+    if self.q_lora_rank is None:
+        q = self.q_proj(hidden_states)
+    else:
+        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+    kv = (
+        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        .transpose(1, 2)
+    )
+
+    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+    kv_seq_len = value_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+    # Difference with original code: k_pe.new_empty creates a constant tensor in TorchScript
+    query_states = torch.concat([q_nope, q_pe], dim=-1)
+    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+    # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+    key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
+    # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+    # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
+def deepseek_v2_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value=None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # modified from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L806
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+        b, h, s, d = q.shape
+        q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+        b, h, s, d = k.shape
+        k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    bsz, q_len, _ = hidden_states.shape
+
+    if self.q_lora_rank is None:
+        q = self.q_proj(hidden_states)
+    else:
+        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+    kv = (
+        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        .transpose(1, 2)
+    )
+
+    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+    kv_seq_len = value_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+    # Difference with original code: k_pe.new_empty creates a constant tensor in TorchScript
+    query_states = torch.concat([q_nope, q_pe], dim=-1)
+    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+    # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+    key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
+    # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+    # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
+def deepseek_moe_infer(self, x, topk_ids, topk_weight):
+    cnts = torch.zeros((topk_ids.shape[0], len(self.experts)))
+    cnts.scatter_(1, topk_ids, 1)
+    tokens_per_expert = cnts.sum(dim=0).to(torch.long)
+    idxs = torch.argsort(topk_ids.view(-1))
+    sorted_tokens = x[idxs // topk_ids.shape[1]]
+
+    outputs = []
+    start_idx = torch.tensor(0, dtype=torch.long)
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        # difference with original: removed skipping of experts with empty num_tokens
+        expert_id = i + self.ep_rank * self.experts_per_rank
+        expert = self.experts[expert_id]
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+        expert_out = expert(tokens_for_this_expert)
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    # difference with original: removed usage of torch.new_empty when outputs is empty
+    outs = torch.cat(outputs, dim=0)
+
+    new_x = torch.zeros_like(outs)
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .to(topk_weight.dtype)
+        .mul_(topk_weight.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .to(new_x.dtype)
+    )
+    return final_out
+
+
 class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py
index 1ab9e1051e..6eef614799 100644
--- a/optimum/exporters/openvino/utils.py
+++ b/optimum/exporters/openvino/utils.py
@@ -26,7 +26,7 @@
 from optimum.exporters import TasksManager
 from optimum.exporters.onnx.base import OnnxConfig
 from optimum.intel.utils import is_transformers_version
-from optimum.intel.utils.import_utils import is_safetensors_available
+from optimum.intel.utils.import_utils import is_openvino_version, is_safetensors_available
 from optimum.utils import is_diffusers_available
 from optimum.utils.save_utils import maybe_save_preprocessors
@@ -344,3 +344,14 @@ def set_simplified_chat_template(ov_tokenizer_model, processor_chat_template=None):
     if tokenizer_chat_template is not None and tokenizer_chat_template in COMPLEX_CHAT_TEMPLATES:
         ov_tokenizer_model.set_rt_info(COMPLEX_CHAT_TEMPLATES[tokenizer_chat_template], "simplified_chat_template")
     return ov_tokenizer_model
+
+
+SKIP_CHECK_TRACE_MODELS = ("deepseek", "deepseek-v2", "deepseek-v3")
+
+
+def allow_skip_tracing_check(library_name, model_type):
+    if is_openvino_version("<", "2025.0.0"):
+        return False
+    if library_name == "diffusers":
+        return True
+    return model_type in SKIP_CHECK_TRACE_MODELS
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 60b7576973..8122ed3de1 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -1001,6 +1001,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
     if is_transformers_version(">=", "4.46.0"):
         SUPPORTED_ARCHITECTURES += ("glm", "mistral-nemo", "minicpm3")
 
+    # openvino 2025.0 required for disabling check_trace
+    if is_openvino_version(">=", "2025.0"):
+        SUPPORTED_ARCHITECTURES += ("deepseek",)
     # gptq and awq install disabled for windows test environment
     if platform.system() != "Windows":
@@ -1030,6 +1033,7 @@
         "exaone",
         "decilm",
         "minicpm3",
+        "deepseek",
     )
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
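[Editor's note: not part of the patch.] The allow_skip_tracing_check helper above gates whether the exporter passes check_trace=False to TorchScript tracing. By default, torch.jit.trace re-runs the traced graph and compares its outputs against eager execution; for graphs whose behavior depends on the example input (diffusers pipelines, and the Deepseek MoE dispatch patched above), that verification can fail spuriously even though the trace is usable for export. A minimal sketch of the mechanism, assuming a made-up ToyBlock module; only the check_trace keyword of torch.jit.trace is real API:

    import torch

    class ToyBlock(torch.nn.Module):
        def forward(self, x):
            return x * 2

    example = torch.ones(1, 4)
    # what ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False} ultimately carries into tracing
    traced = torch.jit.trace(ToyBlock(), example, check_trace=False)
    print(traced(example))  # tensor([[2., 2., 2., 2.]])
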
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index ac39b065ca..140f8f771e 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -51,6 +51,7 @@
     "deberta": "hf-internal-testing/tiny-random-deberta",
     "deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
     "decilm": "katuni4ka/tiny-random-decilm",
+    "deepseek": "katuni4ka/tiny-random-deepseek-v3",
     "deit": "hf-internal-testing/tiny-random-DeiTModel",
     "convnext": "hf-internal-testing/tiny-random-convnext",
     "convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",
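
[Editor's note: not part of the patch.] For reference, a standalone sketch of the dispatch bookkeeping in deepseek_moe_infer: tokens are counted per expert with scatter_, then the flattened (token, expert) assignments are sorted so each expert reads one contiguous slice of sorted_tokens. All tensor values below are made up:

    import torch

    # Toy shapes: 4 tokens, hidden size 8, 4 experts, top-2 routing.
    x = torch.randn(4, 8)
    topk_ids = torch.tensor([[0, 2], [1, 2], [0, 3], [2, 3]])  # expert choices per token
    num_experts = 4

    # Same bookkeeping as deepseek_moe_infer.
    cnts = torch.zeros((topk_ids.shape[0], num_experts))
    cnts.scatter_(1, topk_ids, 1)                       # one-hot marks of chosen experts
    tokens_per_expert = cnts.sum(dim=0).to(torch.long)  # slice size per expert
    idxs = torch.argsort(topk_ids.view(-1))             # assignments grouped by expert id
    sorted_tokens = x[idxs // topk_ids.shape[1]]        # token rows in expert order

    print(tokens_per_expert.tolist())  # [2, 1, 3, 2]: how many rows each expert consumes
    print(idxs.tolist())               # flat assignment indices, grouped by expert id

Because the per-expert loop runs even when a slice is empty (the skip was removed, per the comment in the patch), the unrolled trace is the same regardless of how the example input happens to route tokens; the routing still varies at runtime, which is why these models also need check_trace disabled.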