Commit ee02349

fix baichuan-13b
1 parent 514f054 commit ee02349

4 files changed, +102 −16 lines changed

optimum/exporters/openvino/convert.py

+11-10
@@ -347,6 +347,7 @@ def ts_patched_forward(*args, **kwargs):
 
         with patcher:
             check_dummy_inputs_are_allowed(model, dummy_inputs)
+            sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
             inputs = config.ordered_inputs(model)
             input_names = list(inputs.keys())
             output_names = list(config.outputs.keys())
@@ -376,7 +377,6 @@ def ts_patched_forward(*args, **kwargs):
                 ov_config=ov_config,
             )
 
-            sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
             ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
             if not ordered_dummy_inputs:
                 ordered_dummy_inputs = dummy_inputs
@@ -388,15 +388,16 @@ def ts_patched_forward(*args, **kwargs):
                 out_tensor.get_tensor().set_names({output_names[idx]})
 
             for idx, inp_tensor in enumerate(ov_model.inputs):
-                input_name = ordered_input_names[idx]
-                inp_tensor.get_tensor().set_names({input_name})
-                inp_data = flatten_inputs[idx]
-                static_shape = PartialShape(inp_data.shape)
-                dims = inputs[input_name]
-                for dim in dims:
-                    static_shape[dim] = -1
-                inp_tensor.get_node().set_partial_shape(static_shape)
-                inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
+                if idx < len(ordered_input_names):
+                    input_name = ordered_input_names[idx]
+                    inp_tensor.get_tensor().set_names({input_name})
+                    inp_data = flatten_inputs[idx]
+                    static_shape = PartialShape(inp_data.shape)
+                    dims = inputs.get(input_name, [])
+                    for dim in dims:
+                        static_shape[dim] = -1
+                    inp_tensor.get_node().set_partial_shape(static_shape)
+                    inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
             ov_model.validate_nodes_and_infer_types()
 
             if stateful:
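The convert.py change does two things: the forward-signature lookup (`sig = inspect.signature(...)`) now runs right after `check_dummy_inputs_are_allowed` instead of just before `ordered_dummy_inputs` is built, and the loop that renames and reshapes `ov_model.inputs` no longer assumes a one-to-one match with `ordered_input_names`; `inputs.get(input_name, [])` likewise tolerates an input name that has no dynamic-dimension entry. A minimal standalone sketch of that guard pattern, using hypothetical names rather than the real optimum-intel objects:

# Hypothetical stand-ins for ordered_input_names / inputs / ov_model.inputs above.
ordered_input_names = ["input_ids", "attention_mask"]
dynamic_dims = {"input_ids": {0: "batch", 1: "sequence"}}  # name -> dynamic axes
converted_inputs = ["tensor_0", "tensor_1", "tensor_2"]    # one more input than collected names

for idx, inp in enumerate(converted_inputs):
    if idx < len(ordered_input_names):        # guard: extra inputs keep their default names
        name = ordered_input_names[idx]
        dims = dynamic_dims.get(name, [])     # guard: a name may have no dynamic axes
        print(f"{inp} -> {name}, dynamic axes: {list(dims)}")
    else:
        print(f"{inp} -> left as-is")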

optimum/exporters/openvino/model_configs.py

+2-2
@@ -34,11 +34,11 @@
     BaichuanModelPatcher,
     ChatGLMModelPatcher,
     GemmaModelPatcher,
+    InternLMPatcher,
     LlamaModelPatcher,
     MixtralModelPatcher,
-    QwenModelPatcher,
     MPTModelPatcher,
-    InternLMPatcher,
+    QwenModelPatcher,
 )
 
 
optimum/exporters/openvino/model_patcher.py

+86-1
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging as log
 import math
 import types
@@ -601,6 +602,46 @@ def __exit__(self, exc_type, exc_value, traceback):
         self._model.config.fp16 = self.original_fp16
 
 
+def _baichuan13b_atten_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = True,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    proj = self.W_pack(hidden_states)
+    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :]
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
+    attn_output = attn_output.transpose(1, 2)
+    attn_weights = None
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
 class BaichuanModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
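The new `_baichuan13b_atten_forward` keeps Baichuan-13B's fused `W_pack` in-projection, splits it into query/key/value, concatenates past key/values, and then calls `torch.nn.functional.scaled_dot_product_attention` instead of the model's original attention math, presumably to make the traced graph friendlier for the OpenVINO export. A toy-scale sketch of that split-plus-SDPA pattern (hypothetical sizes, not Baichuan's real config):

import torch
import torch.nn.functional as F

bsz, q_len, num_heads, head_dim = 1, 4, 2, 8
hidden_size = num_heads * head_dim
w_pack = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)  # fused q/k/v projection

hidden_states = torch.randn(bsz, q_len, hidden_size)
proj = w_pack(hidden_states)
# split the fused projection into (q, k, v) exactly like the patched forward does
proj = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
q = proj[0].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
k = proj[1].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
v = proj[2].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
print(attn.transpose(1, 2).reshape(bsz, q_len, hidden_size).shape)  # torch.Size([1, 4, 16])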
@@ -613,6 +654,50 @@ def __init__(
         if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
 
+    def __enter__(self):
+        super().__enter__()
+        # override signature to have position_ids
+        if "position_ids" not in inspect.signature(self._model.forward).parameters:
+            self._model._orig_forward = self._model.forward
+
+            def forward(
+                self,
+                input_ids: torch.LongTensor = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                labels: Optional[torch.LongTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = False,
+                output_hidden_states: Optional[bool] = False,
+                return_dict: Optional[bool] = True,
+                position_ids: Optional[torch.LongTensor] = None,
+            ):
+                return self._orig_forward(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    past_key_values=past_key_values,
+                    inputs_embeds=inputs_embeds,
+                    labels=labels,
+                    use_cache=past_key_values is not None,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=self.config.return_dict,
+                )
+
+            self._model.forward = types.MethodType(forward, self._model)
+            for layer in self._model.model.layers:
+                layer.self_attn._orig_forward = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model, "_orig_forward"):
+            self._model.forward = self._model._orig_forward
+
+            for layer in self._model.model.layers:
+                layer.self_attn.forward = layer.self_attn._orig_forward
+
 
 def _mpt_attention_forward(
     self,
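`BaichuanModelPatcher.__enter__` wraps the model's `forward` so that `position_ids` appears in its signature, then rebinds each layer's `self_attn.forward` to `_baichuan13b_atten_forward` via `types.MethodType`; `__exit__` puts the originals back. A minimal sketch of that rebind/restore pattern on a toy class (not the patcher itself):

import types

class Attn:
    def forward(self, x):
        return f"original({x})"

def patched_forward(self, x):
    return f"patched({x})"

attn = Attn()
attn._orig_forward = attn.forward                       # keep the original bound method
attn.forward = types.MethodType(patched_forward, attn)  # rebind only on this instance
print(attn.forward("x"))                                 # patched(x)

attn.forward = attn._orig_forward                        # restore on "exit"
print(attn.forward("x"))                                 # original(x)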
@@ -680,7 +765,7 @@ def _internlm_attention_forward(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
-    from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
+    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 
     bsz, q_len, _ = hidden_states.size()
 
optimum/intel/openvino/quantization.py

+3-3
@@ -484,9 +484,9 @@ def _quantize_torchmodel(
             subset_size=quantization_config.num_samples,
             ignored_scope=quantization_config.get_ignored_scope_instance(),
             model_type=nncf.ModelType(quantization_config.model_type),
-            preset=nncf.QuantizationPreset.PERFORMANCE
-            if quantization_config.sym
-            else nncf.QuantizationPreset.MIXED,
+            preset=(
+                nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED
+            ),
             fast_bias_correction=quantization_config.fast_bias_correction,
             advanced_parameters=nncf.AdvancedQuantizationParameters(
                 overflow_fix=OverflowFix(quantization_config.overflow_fix)
