Support deepseek models export #1155

Merged (12 commits) on Feb 24, 2025
3 changes: 3 additions & 0 deletions docs/source/openvino/models.mdx
@@ -43,6 +43,9 @@ Here is the list of the supported architectures :
- Deberta-v2
- DeciLM
- Deit
- Deepseek
- Deepseek_v2
- Deepseek_v3
- DistilBert
- Electra
- Encoder Decoder
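With the Deepseek entries added to the supported-architectures list above, such a checkpoint should be exportable through the usual optimum-intel path. A minimal sketch in Python (the model id is illustrative, and `trust_remote_code=True` is assumed because these checkpoints ship custom modeling code):

```python
# Hedged sketch: export a Deepseek model to OpenVINO via optimum-intel.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "deepseek-ai/DeepSeek-V2-Lite"  # illustrative checkpoint
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("An attention mechanism is", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```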
6 changes: 4 additions & 2 deletions optimum/exporters/openvino/convert.py
@@ -47,7 +47,6 @@
_transformers_version,
compare_versions,
is_openvino_tokenizers_version,
is_openvino_version,
is_tokenizers_version,
is_transformers_version,
)
@@ -67,6 +66,7 @@
OV_XML_FILE_NAME,
_get_input_info,
_get_open_clip_submodels_fn_and_export_configs,
allow_skip_tracing_check,
clear_class_registry,
remove_none_from_dummy_inputs,
save_config,
@@ -437,7 +437,9 @@ def ts_patched_forward(*args, **kwargs):
patcher.patched_forward = ts_patched_forward

ts_decoder_kwargs = {}
if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
model_config = getattr(model, "config", {})
model_type = getattr(model_config, "model_type", "").replace("_", "-")
if allow_skip_tracing_check(library_name, model_type):
ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}

with patcher:
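For context, a hedged sketch of what the `check_trace=False` kwarg ultimately controls; the wrapper below is illustrative, not the actual TorchScript decoder used by the exporter:

```python
import torch

def trace_decoder(model, example_inputs, trace_kwargs=None):
    # Illustrative wrapper: trace_kwargs is forwarded to torch.jit.trace.
    # check_trace=False skips re-running the traced module for an output
    # equality check, which can spuriously fail for MoE routing or other
    # data-dependent control flow such as in the deepseek models.
    trace_kwargs = trace_kwargs or {}
    return torch.jit.trace(model, example_inputs, **trace_kwargs)
```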
15 changes: 15 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -75,6 +75,7 @@
CodeGenModelPatcher,
DBRXModelPatcher,
DeciLMModelPatcher,
DeepseekPatcher,
FalconModelPatcher,
FluxTransfromerModelPatcher,
Gemma2ModelPatcher,
@@ -2782,3 +2783,17 @@ class MT5OpenVINOConfig(T5OpenVINOConfig):
)
class LongT5OpenVINOConfig(T5OpenVINOConfig):
pass


@register_in_tasks_manager(
"deepseek-v3", *["text-generation", "text-generation-with-past"], library_name="transformers"
)
@register_in_tasks_manager(
"deepseek-v2", *["text-generation", "text-generation-with-past"], library_name="transformers"
)
@register_in_tasks_manager("deepseek", *["text-generation", "text-generation-with-past"], library_name="transformers")
class DeepseekOpenVINOConfig(MiniCPM3OpenVINOConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return DeepseekPatcher(self, model, model_kwargs=model_kwargs)
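With these registrations in place, the export machinery can resolve the config from the model type alone. A hedged sketch of that lookup (the exact signatures may differ between optimum versions):

```python
# Illustrative resolution of the registered config; not the actual call site.
from optimum.exporters import TasksManager

config_constructor = TasksManager.get_exporter_config_constructor(
    exporter="openvino",
    model=model,  # a loaded deepseek / deepseek-v2 / deepseek-v3 model
    task="text-generation-with-past",
)
export_config = config_constructor(model.config)
patcher = export_config.patch_model_for_export(model)  # a DeepseekPatcher
```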
295 changes: 295 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -3575,6 +3575,301 @@ def __exit__(self, exc_type, exc_value, traceback):
block.self_attn.forward = block.self_attn._orig_forward


class DeepseekPatcher(DecoderModelPatcher):
Collaborator: deepseek-v3 will be included in transformers v4.50 (huggingface/transformers#35926); it could make sense to make sure everything is compatible before merging.

Collaborator (Author): that PR states the "code relies heavily on original remote code"; it is exactly the same reference code that I used here.

def __enter__(self):
super().__enter__()
self_attn = {
"deepseek_v3": deepseek_v3_attn_forward,
"deepseek_v2": deepseek_v2_attn_forward,
"deepseek": minicpm3_attn_forward,
}

self_attn_fwd = self_attn.get(self._model.config.model_type)
for block in self._model.model.layers:
if self_attn_fwd is not None:
block.self_attn._orig_forward = block.self_attn.forward
block.self_attn.forward = types.MethodType(self_attn_fwd, block.self_attn)
if hasattr(block.mlp, "moe_infer"):
block.mlp._orig_moe_infer = block.mlp.moe_infer
block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.model.layers:
block.self_attn.forward = block.self_attn._orig_forward
if hasattr(block.mlp, "_orig_moe_infer"):
block.mlp.moe_infer = block.mlp._orig_moe_infer
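The patcher follows the usual context-manager pattern: forwards are swapped on entry and restored on exit. A minimal sketch of how it would wrap tracing (constructing the patcher directly and the example inputs are illustrative; the real pipeline goes through `DeepseekOpenVINOConfig.patch_model_for_export`):

```python
# Hypothetical usage, for illustration only.
patcher = DeepseekPatcher(export_config, model, model_kwargs={})
with patcher:
    # Inside the context, self_attn.forward and mlp.moe_infer point at the
    # export-friendly implementations defined below.
    traced = torch.jit.trace(model, example_inputs, check_trace=False)
# On exit, the original forward implementations are restored.
```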


def deepseek_v3_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# modified from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L751
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
q_fp32 = q.to(dtype=torch.float32, device=q.device)
k_fp32 = k.to(dtype=torch.float32, device=k.device)
q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)

if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)

bsz, q_len, _ = hidden_states.size()

if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)

k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

# Difference from the original code: k_pe.new_empty creates a constant tensor in TorchScript
query_states = torch.concat([q_nope, q_pe], dim=-1)
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
# key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
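A small self-contained check of the rewrite flagged in the comment above (shapes are made up for illustration): filling a `new_empty` buffer bakes a fixed shape into the traced graph, while `torch.concat` keeps it shape-agnostic, and the two produce identical tensors.

```python
import torch

q_nope = torch.randn(1, 2, 4, 3)  # [bs, num_heads, q_len, qk_nope_head_dim]
q_pe = torch.randn(1, 2, 4, 2)    # [bs, num_heads, q_len, qk_rope_head_dim]

buf = q_pe.new_empty(1, 2, 4, 5)  # original approach: preallocate, then fill
buf[..., :3] = q_nope
buf[..., 3:] = q_pe

cat = torch.concat([q_nope, q_pe], dim=-1)  # tracing-friendly rewrite
assert torch.equal(buf, cat)
```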


def deepseek_v2_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# modified from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L806
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)

b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed

if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)

bsz, q_len, _ = hidden_states.shape

if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)

k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

# Difference from the original code: k_pe.new_empty creates a constant tensor in TorchScript
query_states = torch.concat([q_nope, q_pe], dim=-1)
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
# key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
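The V2 variant additionally reorders the rotary dimensions before applying RoPE, because the checkpoint stores them interleaved in (even, odd) pairs. A toy illustration of that `view`/`transpose`/`reshape` shuffle (the dimension size is arbitrary):

```python
import torch

d = 6
x = torch.arange(d).view(1, 1, 1, d)  # [0, 1, 2, 3, 4, 5]
de_interleaved = x.view(1, 1, 1, d // 2, 2).transpose(4, 3).reshape(1, 1, 1, d)
print(de_interleaved.flatten().tolist())  # [0, 2, 4, 1, 3, 5]
```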


def deepseek_moe_infer(self, x, topk_ids, topk_weight):
cnts = torch.zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0).to(torch.long)
idxs = torch.argsort(topk_ids.view(-1))
sorted_tokens = x[idxs // topk_ids.shape[1]]

outputs = []
start_idx = torch.tensor(0, dtype=torch.long)
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
# difference with original: removed skipping of the expert when num_tokens is zero
expert_id = i + self.ep_rank * self.experts_per_rank
expert = self.experts[expert_id]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx

# difference with original: removed use of torch.new_empty when outputs is empty
outs = torch.cat(outputs, dim=0)

new_x = torch.zeros_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.to(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.to(new_x.dtype)
)
return final_out
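To make the dispatch in `deepseek_moe_infer` concrete, here is a toy example with two tokens, three experts, and top-2 routing (values are arbitrary): each token's expert ids are counted per expert, then `argsort` groups the flattened (token, slot) pairs by expert id so that each expert's batch can be sliced contiguously.

```python
import torch

topk_ids = torch.tensor([[0, 2], [1, 2]])  # 2 tokens, top-2 experts each
num_experts = 3

cnts = torch.zeros((topk_ids.shape[0], num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0).to(torch.long)

idxs = torch.argsort(topk_ids.view(-1))    # flattened slots grouped by expert id
source_tokens = idxs // topk_ids.shape[1]  # which input token feeds each grouped slot
print(tokens_per_expert.tolist())          # [1, 1, 2]
print(source_tokens.tolist())              # e.g. [0, 1, 0, 1] (tie order may vary)
```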


class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
def __init__(
self,
Expand Down
13 changes: 12 additions & 1 deletion optimum/exporters/openvino/utils.py
@@ -26,7 +26,7 @@
from optimum.exporters import TasksManager
from optimum.exporters.onnx.base import OnnxConfig
from optimum.intel.utils import is_transformers_version
from optimum.intel.utils.import_utils import is_safetensors_available
from optimum.intel.utils.import_utils import is_openvino_version, is_safetensors_available
from optimum.utils import is_diffusers_available
from optimum.utils.save_utils import maybe_save_preprocessors

@@ -344,3 +344,14 @@ def set_simplified_chat_template(ov_tokenizer_model, processor_chat_template=Non
if tokenizer_chat_template is not None and tokenizer_chat_template in COMPLEX_CHAT_TEMPLATES:
ov_tokenizer_model.set_rt_info(COMPLEX_CHAT_TEMPLATES[tokenizer_chat_template], "simplified_chat_template")
return ov_tokenizer_model


SKIP_CHECK_TRACE_MODELS = ("deepseek", "deepseek-v2", "deepseek-v3")


def allow_skip_tracing_check(library_name, model_type):
if is_openvino_version("<", "2025.0.0"):
return False
if library_name == "diffusers":
return True
return model_type in SKIP_CHECK_TRACE_MODELS
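A few hedged examples of how the gate above behaves, assuming an installed OpenVINO >= 2025.0 (with an older runtime every call returns False):

```python
# Assuming is_openvino_version("<", "2025.0.0") evaluates to False:
assert allow_skip_tracing_check("diffusers", "stable-diffusion") is True
assert allow_skip_tracing_check("transformers", "deepseek-v2") is True
assert allow_skip_tracing_check("transformers", "llama") is False
```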