@@ -14,7 +14,7 @@
 
 import logging as log
 import types
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -279,3 +279,201 @@ def __enter__(self):
             layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                 rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
             )
+
+
+SUPPORT_SDPA = is_torch_version(">", "2.1.0")
+
+
+def _qwen_rotate_half(x):
+    from einops import rearrange
+
+    x = rearrange(x, "... (j d) -> ... j d", j=2)
+    x1, x2 = x.unbind(dim=-2)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def _qwen_apply_rotary_pos_emb(t, freqs):
+    cos, sin = freqs
+    rot_dim = freqs[0].shape[-1]
+    t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
+    t_ = t_.float()
+    t_pass_ = t_pass_.float()
+    t_ = (t_ * cos) + (_qwen_rotate_half(t_) * sin)
+    return torch.cat((t_, t_pass_), dim=-1).type_as(t)
+
+
+def _qwen_quantize_cache_v(fdata, bits, qmax, qmin):
+    # b, s, head, h-dim->b, head, s, h-dim
+    qtype = torch.uint8
+    device = fdata.device
+    shape = fdata.shape
+
+    fdata_cal = torch.flatten(fdata, 2)
+    fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
+    fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
+    # Compute params
+    if qmax.device != fmax.device:
+        qmax = qmax.to(device)
+        qmin = qmin.to(device)
+    scale = (fmax - fmin) / (qmax - qmin)
+    zero = qmin - fmin / scale
+    scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
+    zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
+    # Quantize
+    res_data = fdata / scale + zero
+    qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
+    return qdata.contiguous(), scale, zero
+
+
+def _qwen_attention_forward(
+    self,
+    hidden_states: Optional[Tuple[torch.FloatTensor]],
+    rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
+    layer_past: Optional[Tuple[torch.Tensor]] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    head_mask: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+):
+    mixed_x_layer = self.c_attn(hidden_states)
+
+    query, key, value = mixed_x_layer.split(self.split_size, dim=2)
+
+    query = self._split_heads(query, self.num_heads, self.head_dim)
+    key = self._split_heads(key, self.num_heads, self.head_dim)
+    value = self._split_heads(value, self.num_heads, self.head_dim)
+
+    if rotary_pos_emb_list is not None:
+        cur_len = query.shape[1]
+        if len(rotary_pos_emb_list) == 1:
+            rotary_pos_emb = rotary_pos_emb_list[0]
+            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
+            q_pos_emb, k_pos_emb = rotary_pos_emb
+            # Slice the pos emb for current inference
+            query = _qwen_apply_rotary_pos_emb(query, q_pos_emb)
+            key = _qwen_apply_rotary_pos_emb(key, k_pos_emb)
+        else:
+            query_list = []
+            key_list = []
+            for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
+                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query_list += [_qwen_apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
+                key_list += [_qwen_apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
+            query = torch.cat(query_list, dim=0)
+            key = torch.cat(key_list, dim=0)
+
+    if self.use_cache_quantization:
+        key = _qwen_quantize_cache_v(key.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax)
+        value = _qwen_quantize_cache_v(value.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax)
+
+    if layer_past is not None:
+        past_key, past_value = layer_past[0], layer_past[1]
+        if self.use_cache_quantization:
+            # use_cache_quantization:
+            #   present=((q_key, key_scale, key_zero_point),
+            #            (q_value, value_scale, value_zero_point))
+            key = (
+                torch.cat((past_key[0], key[0]), dim=2),
+                torch.cat((past_key[1], key[1]), dim=2),
+                torch.cat((past_key[2], key[2]), dim=2),
+            )
+            value = (
+                torch.cat((past_value[0], value[0]), dim=2),
+                torch.cat((past_value[1], value[1]), dim=2),
+                torch.cat((past_value[2], value[2]), dim=2),
+            )
+        else:
+            # not use_cache_quantization:
+            #   present=(key, value)
+            key = torch.cat((past_key, key), dim=1)
+            value = torch.cat((past_value, value), dim=1)
+
+    if use_cache:
+        present = (key, value)
+    else:
+        present = None
+
+    if self.use_logn_attn and not self.training:
+        if self.use_cache_quantization:
+            seq_start = key[0].size(2) - query.size(1)
+            seq_end = key[0].size(2)
+        else:
+            seq_start = key.size(1) - query.size(1)
+            seq_end = key.size(1)
+        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
+        query = query * logn_tensor.expand_as(query)
+
+    if self.use_flash_attn and not self.is_fp32 and query.is_cuda:
+        q, k, v = query, key, value
+        attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
+    else:
+        registered_causal_mask = torch.tril(
+            torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
+        ).view(1, 1, key.size(1), key.size(1))
+        query = query.permute(0, 2, 1, 3)
+        if not self.use_cache_quantization:
+            key = key.permute(0, 2, 1, 3)
+            value = value.permute(0, 2, 1, 3)
+
+        if not self.use_cache_quantization and SUPPORT_SDPA:
+            causal_mask = registered_causal_mask[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)]
+            if attention_mask is not None:
+                attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1).masked_fill(
+                    ~causal_mask, torch.finfo(query.dtype).min
+                )
+            else:
+                attention_mask = causal_mask
+            attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
+            attn_weight = None
+        else:
+            attn_output, attn_weight = self._attn(query, key, value, registered_causal_mask, attention_mask, head_mask)
+    context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+    attn_output = self.c_proj(context_layer)
+
+    outputs = (attn_output, present)
+    if output_attentions:
+        if self.use_flash_attn and not self.is_fp32:
+            raise ValueError("Cannot output attentions while using flash-attn")
+        else:
+            outputs += (attn_weight,)
+
+    return outputs
+
+
+class QwenModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        self.original_fp16 = model.config.fp16
+        self.original_bf16 = model.config.bf16
+        model.config.bf16 = False
+        model.config.fp16 = False
+        if self.original_fp16 or self.original_bf16:
+            model.to(torch.float32)
+        model.transformer.rotary_emb(2048)
+
+    def __enter__(self):
+        super().__enter__()
+        for block in self._model.transformer.h:
+            block.attn._orig_forward = block.attn.forward
+            block.attn.forward = types.MethodType(_qwen_attention_forward, block.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.transformer.h:
+            block.attn.forward = block.attn._orig_forward
+        self._model.config.bf16 = self.original_bf16
+        self._model.config.fp16 = self.original_fp16
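
For reference, `_qwen_quantize_cache_v` above stores each cached key/value tensor as a `(uint8 data, scale, zero_point)` triple, calibrated from the per-(batch, head) min/max. The round-trip sketch below is not part of the diff; it only assumes the helper above is in scope and uses the uint8 limits that `cache_qmin`/`cache_qmax` would normally supply.

import torch

# (batch, num_heads, seq_len, head_dim) -- the layout after the permute in
# _qwen_attention_forward
value = torch.randn(1, 32, 10, 128)
qmin = torch.tensor(0.0)
qmax = torch.tensor(255.0)
qdata, scale, zero = _qwen_quantize_cache_v(value, bits=8, qmax=qmax, qmin=qmin)

# Dequantize by inverting `fdata / scale + zero`; the reconstruction error stays
# within one quantization step per element.
recovered = (qdata.float() - zero) * scale
print((recovered - value).abs().max())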
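`QwenModelPatcher` follows the same context-manager pattern as the existing patchers in this file: the constructor forces the config to fp32 and invokes `model.transformer.rotary_emb(2048)` once, `__enter__` swaps every block's attention forward for `_qwen_attention_forward`, and `__exit__` restores the original forwards along with the saved fp16/bf16 flags. A minimal usage sketch, where `export_config`, `qwen_model`, and `run_export` are illustrative placeholders rather than names taken from this diff:

patcher = QwenModelPatcher(export_config, qwen_model, model_kwargs={})
with patcher:
    # While the context is active, block.attn.forward is _qwen_attention_forward
    # for every block in qwen_model.transformer.h.
    run_export(qwen_model)  # placeholder for the actual tracing/export call
# On exit, the original forwards and the fp16/bf16 config flags are restored.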