Commit 4854ce1

Commit message: qwen

1 parent 53686cf, commit 4854ce1

File tree: 4 files changed, +323 -3 lines changed

optimum/exporters/openvino/model_configs.py

+119 -1
@@ -20,7 +20,6 @@
 
 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
 from optimum.exporters.onnx.model_configs import GemmaOnnxConfig
-from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, GemmaModelPatcher, MixtralModelPatcher
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (
@@ -31,6 +30,8 @@
 )
 from optimum.utils.normalized_config import NormalizedTextConfig
 
+from .model_patcher import ChatGLMModelPatcher, GemmaModelPatcher, MixtralModelPatcher, QwenModelPatcher
+
 
 def init_model_configs():
     supported_model_types = [
@@ -268,3 +269,120 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return GemmaModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+class QwenDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.kv_channels = normalized_config.kv_channels
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        past_key_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels)
+        past_value_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels)
+        return [
+            (
+                self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
+@register_in_tasks_manager("qwen", *["text-generation", "text-generation-with-past"])
+class QwenOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
+        num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
+    )
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, QwenDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = QwenDummyPastKeyValuesGenerator
+    no_position_ids = False
+
+    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
+        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
+
+        dummy_inputs = {}
+        input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
+        if self.use_past_in_inputs and self.use_cache_branch is not False:
+            input_names.append("past_key_values")
+
+        for input_name in input_names:
+            input_was_inserted = False
+            for dummy_input_gen in dummy_inputs_generators:
+                if dummy_input_gen.supports_input(input_name):
+                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
+                        dummy_input_gen,
+                        input_name,
+                        framework,
+                        input_shapes=kwargs,
+                    )
+                    input_was_inserted = True
+                    break
+            if not input_was_inserted:
+                raise RuntimeError(
+                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
+                )
+
+        # refer to https://github.com/huggingface/optimum/pull/764
+        if (
+            self.use_past_in_inputs
+            and self.PAD_ATTENTION_MASK_TO_PAST
+            and self.use_cache_branch is not False
+            and "attention_mask" in dummy_inputs
+        ):
+            # Obtain the past sequence length from the value instead of the key (Bloom). Qwen has seq_len in dim 1 instead of -2.
+            past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[1]
+
+            dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
+                dummy_inputs["attention_mask"],
+                desired_length=past_present_length,
+                dim=1,
+                dtype=dummy_inputs["attention_mask"].dtype,
+            )
+
+        return dummy_inputs
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        """
+        Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction.
+
+        Args:
+            inputs_or_outputs (`Dict[str, Dict[int, str]]`): The mapping to fill.
+            direction (`str`):
+                either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the
+                output mapping, this is important for axes naming.
+        """
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            decoder_sequence_name = "past_sequence_length"
+            name = "past_key_values"
+        else:
+            decoder_sequence_name = "past_sequence_length + 1"
+            name = "present"
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 1: decoder_sequence_name}
+            inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 1: decoder_sequence_name}
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return QwenModelPatcher(self, model, model_kwargs=model_kwargs)
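
Why a dedicated generator: the Qwen remote-code model keeps its KV cache as (batch, seq_len, num_heads, kv_channels), i.e. the sequence axis sits on dim 1 rather than -2, which is also why add_past_key_values marks dim 1 as the dynamic past_sequence_length axis. Below is a minimal sketch of the generator in isolation; the tiny config values are made up for illustration only, and the constructor signature is assumed to match the super().__init__ call above.

    from transformers import PretrainedConfig

    # hypothetical tiny config, only to check tensor shapes
    cfg = PretrainedConfig(num_hidden_layers=2, num_attention_heads=4, hidden_size=32, kv_channels=8)
    normalized = QwenOpenVINOConfig.NORMALIZED_CONFIG_CLASS(cfg)

    pkv_gen = QwenDummyPastKeyValuesGenerator("text-generation", normalized, batch_size=1, sequence_length=7)
    key, value = pkv_gen.generate("past_key_values", framework="pt")[0]
    print(key.shape)  # expected: torch.Size([1, 7, 4, 8]) -- sequence length on dim 1, not dim -2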

optimum/exporters/openvino/model_patcher.py

+199 -1
@@ -14,7 +14,7 @@
 
 import logging as log
 import types
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -279,3 +279,201 @@ def __enter__(self):
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
+
+
+SUPPORT_SDPA = is_torch_version(">", "2.1.0")
+
+
+def _qwen_rotate_half(x):
+    from einops import rearrange
+
+    x = rearrange(x, "... (j d) -> ... j d", j=2)
+    x1, x2 = x.unbind(dim=-2)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def _qwen_apply_rotary_pos_emb(t, freqs):
+    cos, sin = freqs
+    rot_dim = freqs[0].shape[-1]
+    cos, sin = freqs
+    t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
+    t_ = t_.float()
+    t_pass_ = t_pass_.float()
+    t_ = (t_ * cos) + (_qwen_rotate_half(t_) * sin)
+    return torch.cat((t_, t_pass_), dim=-1).type_as(t)
+
+
+def _qwen_quantize_cache_v(fdata, bits, qmax, qmin):
+    # b, s, head, h-dim->b, head, s, h-dim
+    qtype = torch.uint8
+    device = fdata.device
+    shape = fdata.shape
+
+    fdata_cal = torch.flatten(fdata, 2)
+    fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
+    fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
+    # Compute params
+    if qmax.device != fmax.device:
+        qmax = qmax.to(device)
+        qmin = qmin.to(device)
+    scale = (fmax - fmin) / (qmax - qmin)
+    zero = qmin - fmin / scale
+    scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
+    zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
+    # Quantize
+    res_data = fdata / scale + zero
+    qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
+    return qdata.contiguous(), scale, zero
+
+
+def _qwen_attention_forward(
+    self,
+    hidden_states: Optional[Tuple[torch.FloatTensor]],
+    rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
+    layer_past: Optional[Tuple[torch.Tensor]] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    head_mask: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+):
+    mixed_x_layer = self.c_attn(hidden_states)
+
+    query, key, value = mixed_x_layer.split(self.split_size, dim=2)
+
+    query = self._split_heads(query, self.num_heads, self.head_dim)
+    key = self._split_heads(key, self.num_heads, self.head_dim)
+    value = self._split_heads(value, self.num_heads, self.head_dim)
+
+    if rotary_pos_emb_list is not None:
+        cur_len = query.shape[1]
+        if len(rotary_pos_emb_list) == 1:
+            rotary_pos_emb = rotary_pos_emb_list[0]
+            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
+            q_pos_emb, k_pos_emb = rotary_pos_emb
+            # Slice the pos emb for current inference
+            query = _qwen_apply_rotary_pos_emb(query, q_pos_emb)
+            key = _qwen_apply_rotary_pos_emb(key, k_pos_emb)
+        else:
+            query_list = []
+            key_list = []
+            for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
+                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query_list += [_qwen_apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
+                key_list += [_qwen_apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
+            query = torch.cat(query_list, dim=0)
+            key = torch.cat(key_list, dim=0)
+
+    if self.use_cache_quantization:
+        key = _qwen_quantize_cache_v(key.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax)
+        value = _qwen_quantize_cache_v(value.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax)
+
+    if layer_past is not None:
+        past_key, past_value = layer_past[0], layer_past[1]
+        if self.use_cache_quantization:
+            # use_cache_quantization:
+            #     present=((q_key,key_scale,key_zero_point),
+            #              (q_value,value_scale,value_zero_point))
+            key = (
+                torch.cat((past_key[0], key[0]), dim=2),
+                torch.cat((past_key[1], key[1]), dim=2),
+                torch.cat((past_key[2], key[2]), dim=2),
+            )
+            value = (
+                torch.cat((past_value[0], value[0]), dim=2),
+                torch.cat((past_value[1], value[1]), dim=2),
+                torch.cat((past_value[2], value[2]), dim=2),
+            )
+        else:
+            # not use_cache_quantization:
+            #     present=(key,value)
+            key = torch.cat((past_key, key), dim=1)
+            value = torch.cat((past_value, value), dim=1)
+
+    if use_cache:
+        present = (key, value)
+    else:
+        present = None
+
+    if self.use_logn_attn and not self.training:
+        if self.use_cache_quantization:
+            seq_start = key[0].size(2) - query.size(1)
+            seq_end = key[0].size(2)
+        else:
+            seq_start = key.size(1) - query.size(1)
+            seq_end = key.size(1)
+        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
+        query = query * logn_tensor.expand_as(query)
+
+    if self.use_flash_attn and not self.is_fp32 and query.is_cuda:
+        q, k, v = query, key, value
+        attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
+    else:
+        registered_causal_mask = torch.tril(
+            torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
+        ).view(1, 1, key.size(1), key.size(1))
+        query = query.permute(0, 2, 1, 3)
+        if not self.use_cache_quantization:
+            key = key.permute(0, 2, 1, 3)
+            value = value.permute(0, 2, 1, 3)
+
+        if not self.use_cache_quantization and SUPPORT_SDPA:
+            causal_mask = registered_causal_mask[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)]
+            if attention_mask is not None:
+                attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1).masked_fill(
+                    ~causal_mask, torch.finfo(query.dtype).min
+                )
+            else:
+                attention_mask = causal_mask
+            attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
+            attn_weight = None
+        else:
+            attn_output, attn_weight = self._attn(query, key, value, registered_causal_mask, attention_mask, head_mask)
+    context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+    attn_output = self.c_proj(context_layer)
+
+    outputs = (attn_output, present)
+    if output_attentions:
+        if self.use_flash_attn and not self.is_fp32:
+            raise ValueError("Cannot output attentions while using flash-attn")
+        else:
+            outputs += (attn_weight,)
+
+    return outputs
+
+
+class QwenModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        self.original_fp16 = model.config.fp16
+        self.original_bf16 = model.config.bf16
+        model.config.bf16 = False
+        model.config.fp16 = False
+        if self.original_fp16 or self.original_bf16:
+            model.to(torch.float32)
+        model.transformer.rotary_emb(2048)
+
+    def __enter__(self):
+        super().__enter__()
+        for block in self._model.transformer.h:
+            block.attn._orig_forward = block.attn.forward
+            block.attn.forward = types.MethodType(_qwen_attention_forward, block.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.transformer.h:
+            block.attn.forward = block.attn._orig_forward
+        self._model.config.bf16 = self.original_bf16
+        self._model.config.fp16 = self.original_fp16
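
QwenModelPatcher follows the same context-manager pattern as the other patchers in this file: __init__ forces the config to fp32 and precomputes the rotary embeddings, __enter__ swaps every attention block's forward for the SDPA-friendly _qwen_attention_forward, and __exit__ restores the original forwards and the fp16/bf16 flags. A rough sketch of that behaviour, assuming a Qwen checkpoint loaded with trust_remote_code=True (the tiny test model from utils_tests.py is used here) and assuming the config class accepts the usual task/use_past constructor arguments:

    from transformers import AutoModelForCausalLM

    from optimum.exporters.openvino.model_configs import QwenOpenVINOConfig
    from optimum.exporters.openvino.model_patcher import QwenModelPatcher, _qwen_attention_forward

    model = AutoModelForCausalLM.from_pretrained("katuni4ka/tiny-random-qwen", trust_remote_code=True)
    export_config = QwenOpenVINOConfig(model.config, task="text-generation", use_past=True)

    with QwenModelPatcher(export_config, model, model_kwargs={}):
        # inside the context, the patched forward is installed on every block
        assert model.transformer.h[0].attn.forward.__func__ is _qwen_attention_forward

    # leaving the context restores the original bound method
    assert model.transformer.h[0].attn.forward is model.transformer.h[0].attn._orig_forward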

tests/openvino/test_modeling.py

+4 -1
@@ -495,12 +495,13 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "mpt",
         "opt",
         "pegasus",
+        "qwen",
         "qwen2",
         "stablelm",
     )
     GENERATION_LENGTH = 100
     IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")
-    REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais")
+    REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen")
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):
@@ -531,6 +532,8 @@ def test_compare_to_transformers(self, model_arch):
         )
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
+        if model_arch == "qwen":
+            transformers_model.to(torch.float32)
         tokens = tokenizer(
             "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
         )
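
The new branch forces the PyTorch reference model to float32, presumably because the remote-code Qwen model can otherwise load in the reduced precision declared in its config, while the exported OpenVINO model is traced in fp32 by QwenModelPatcher. A condensed sketch of what this test exercises end to end (an environment with optimum-intel, OpenVINO, and the Qwen remote-code dependencies is assumed; the tolerance is illustrative):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from optimum.intel import OVModelForCausalLM

    model_id = "katuni4ka/tiny-random-qwen"  # tiny test checkpoint registered in utils_tests.py below
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    tokens = tokenizer("This is a sample", return_tensors="pt")

    # export happens on the fly; qwen needs trust_remote_code since it ships custom modeling code
    ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
    ov_logits = ov_model(**tokens).logits

    pt_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(torch.float32)
    with torch.no_grad():
        pt_logits = pt_model(**tokens).logits

    print(torch.allclose(ov_logits, pt_logits, atol=1e-4))  # the surrounding test makes a similar comparison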

tests/openvino/utils_tests.py

+1 -0
@@ -72,6 +72,7 @@
     "pegasus": "hf-internal-testing/tiny-random-pegasus",
     "pix2struct": "fxmarty/pix2struct-tiny-random",
     "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
+    "qwen": "katuni4ka/tiny-random-qwen",
     "qwen2": "Qwen/Qwen1.5-0.5B",
     "resnet": "hf-internal-testing/tiny-random-resnet",
     "roberta": "hf-internal-testing/tiny-random-roberta",
