Commit bbdca54
chatglm export
1 parent c1064fd

2 files changed: +289 -3 lines

optimum/exporters/openvino/model_configs.py

+164 -3
@@ -12,17 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 from packaging import version
 from transformers.utils import is_tf_available
 
 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
 from optimum.exporters.tasks import TasksManager
-from optimum.utils.input_generators import DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator
+from optimum.utils import DEFAULT_DUMMY_SHAPES
+from optimum.utils.input_generators import (
+    DummyInputGenerator,
+    DummyPastKeyValuesGenerator,
+    DummyTextInputGenerator,
+    MistralDummyPastKeyValuesGenerator,
+)
 from optimum.utils.normalized_config import NormalizedTextConfig
 
-from .model_patcher import MixtralModelPatcher
+from .model_patcher import ChatGLMModelPatcher, MixtralModelPatcher
 
 
 if TYPE_CHECKING:
@@ -70,6 +76,161 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
+class ChatGLM2DummyTextInputGenerator(DummyTextInputGenerator):
+    SUPPORTED_INPUT_NAMES = {
+        "input_ids",
+        "attention_mask",
+        "token_type_ids",
+        "position_ids",
+    }
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        import torch
+
+        input = super().generate(input_name, framework, int_dtype, float_dtype)
+        if input_name == "attention_mask":
+            input = torch.ones(input.shape, dtype=input.dtype)
+        if input_name == "position_ids":
+            bs = input.shape[0]
+            input = torch.range(0, input.shape[1], dtype=input.dtype).repeat(bs, 1)
+        return input
+
+
+class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.multi_query_group_num = normalized_config.multi_query_group_num
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        past_key_shape = (
+            self.sequence_length,
+            self.batch_size,
+            self.multi_query_group_num,
+            self.head_dim,
+        )
+        past_value_shape = (
+            self.sequence_length,
+            self.batch_size,
+            self.multi_query_group_num,
+            self.head_dim,
+        )
+        return [
+            (
+                self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
+@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
+class ChatGLM2OpenVINOConfig(TextDecoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers")
+    DUMMY_INPUT_GENERATOR_CLASSES = (ChatGLM2DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator
+    no_position_ids = False
+
+    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
+        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
+
+        dummy_inputs = {}
+        input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
+        if self.use_past_in_inputs and self.use_cache_branch is not False:
+            input_names.append("past_key_values")
+
+        for input_name in input_names:
+            input_was_inserted = False
+            for dummy_input_gen in dummy_inputs_generators:
+                if dummy_input_gen.supports_input(input_name):
+                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
+                        dummy_input_gen,
+                        input_name,
+                        framework,
+                        input_shapes=kwargs,
+                    )
+                    input_was_inserted = True
+                    break
+            if not input_was_inserted:
+                raise RuntimeError(
+                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
+                )
+
+        # refer to https://github.com/huggingface/optimum/pull/764
+        cond1 = self.use_past_in_inputs
+        cond2 = self.PAD_ATTENTION_MASK_TO_PAST
+        cond3 = self.use_cache_branch is not False
+        cond4 = "attention_mask" in dummy_inputs
+        if cond1 and cond2 and cond3 and cond4:
+            # Obtain the past sequence length from the value instead of the key (Bloom).
+            past_length = dummy_inputs["past_key_values"][0][1].shape[0]
+            for k, v in dummy_inputs.items():
+                if k not in ["attention_mask", "past_key_values"]:
+                    dummy_inputs[k] = v[:, -1:]
+
+            dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
+                dummy_inputs["attention_mask"],
+                desired_length=past_length + 1,
+                dim=1,
+                dtype=dummy_inputs["attention_mask"].dtype,
+            )
+
+        return dummy_inputs
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = super().inputs
+        if not self.no_position_ids and self.task == "text-generation":
+            common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
+
+        return common_inputs
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        """
+        Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction.
+
+        Args:
+            inputs_or_outputs (`Dict[str, Dict[int, str]]`): The mapping to fill.
+            direction (`str`):
+                either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the
+                output mapping, this is important for axes naming.
+        """
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

+        if direction == "inputs":
+            decoder_sequence_name = "past_sequence_length"
+            name = "past_key_values"
+        else:
+            decoder_sequence_name = "past_sequence_length + 1"
+            name = "present"
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name}
+            inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name}
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return ChatGLMModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"])
 class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
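
Note on the dummy generators above: below is a minimal plain-PyTorch sketch (not part of the diff, example sizes only) of the tensors they are meant to produce, i.e. a full attention mask, one row of increasing position ids per batch element, and per-layer (key, value) pairs laid out sequence-first with multi_query_group_num key/value heads. The sketch uses torch.arange; torch.range, as used in the diff, is deprecated and includes its end point, so it yields sequence_length + 1 positions per row.

# Illustrative sketch only; shapes mirror ChatGLM2DummyTextInputGenerator /
# ChatGLM2DummyPastKeyValuesGenerator above. All sizes are example values.
import torch

batch_size, sequence_length = 2, 8
multi_query_group_num, head_dim, num_layers = 2, 128, 28  # example ChatGLM2-like sizes

# Attention mask of ones and monotonically increasing position ids per batch row.
attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.int64)
position_ids = torch.arange(0, sequence_length, dtype=torch.int64).repeat(batch_size, 1)

# ChatGLM2 caches are laid out sequence-first: (seq_len, batch, kv_groups, head_dim).
past_key_values = [
    (
        torch.rand(sequence_length, batch_size, multi_query_group_num, head_dim),
        torch.rand(sequence_length, batch_size, multi_query_group_num, head_dim),
    )
    for _ in range(num_layers)
]

assert position_ids.shape == (batch_size, sequence_length)
assert past_key_values[0][0].shape == (sequence_length, batch_size, multi_query_group_num, head_dim)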
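
For context, a hedged end-to-end sketch of how the newly registered "chatglm" config is exercised from the user side through optimum-intel's OVModelForCausalLM; the checkpoint id and generation settings below are illustrative and not taken from this commit.

# Hedged usage sketch; assumes optimum-intel's OVModelForCausalLM API and a
# ChatGLM checkpoint that requires trust_remote_code. Not part of the commit.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "THUDM/chatglm2-6b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# export=True converts the PyTorch model to OpenVINO IR using the registered config.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("What is OpenVINO?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))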

optimum/exporters/openvino/model_patcher.py

+125
@@ -14,9 +14,12 @@
 
 import logging as log
 import types
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.utils import is_tf_available
 
 from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
 from optimum.intel.utils.import_utils import (
@@ -27,6 +30,15 @@
 )
 
 
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+
+    from optimum.exporters.onnx.config import OnnxConfig
+
+    if is_tf_available():
+        from transformers.modeling_tf_utils import TFPreTrainedModel
+
+
 def patch_model_with_bettertransformer(model):
     # check that the model has not yet been pathced
     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
@@ -107,3 +119,116 @@ def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         for layer in self._model.model.layers:
             layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward
+
+
+def _chatglm_transformer_forward(
+    self,
+    input_ids,
+    position_ids: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.BoolTensor] = None,
+    full_attention_mask: Optional[torch.BoolTensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+    inputs_embeds: Optional[torch.Tensor] = None,
+    use_cache: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+):
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    batch_size, seq_length = input_ids.shape
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embedding(input_ids)
+
+    if self.pre_seq_len is not None:
+        if past_key_values is None:
+            past_key_values = self.get_prompt(
+                batch_size=batch_size,
+                device=input_ids.device,
+                dtype=inputs_embeds.dtype,
+            )
+        if attention_mask is not None:
+            attention_mask = torch.cat(
+                [
+                    attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                    attention_mask,
+                ],
+                dim=-1,
+            )
+
+    if full_attention_mask is None:
+        if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+            full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+        elif past_key_values is not None:
+            full_attention_mask = torch.ones(
+                batch_size,
+                seq_length,
+                seq_length,
+                device=input_ids.device,
+                dtype=torch.float,
+            ) * float("-inf")
+            full_attention_mask.triu_(diagonal=1)
+            past_length = 0
+            if past_key_values:
+                past_length = past_key_values[0][0].shape[0]
+            if past_length:
+                full_attention_mask = torch.cat(
+                    (
+                        torch.zeros(batch_size, seq_length, past_length, device=input_ids.device),
+                        full_attention_mask,
+                    ),
+                    dim=-1,
+                )
+            full_attention_mask.unsqueeze_(1)
+
+    # Rotary positional embeddings
+    rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+    if position_ids is not None:
+        rotary_pos_emb = rotary_pos_emb[position_ids]
+    else:
+        rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+    rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+    # Run encoder.
+    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+        inputs_embeds,
+        full_attention_mask,
+        rotary_pos_emb=rotary_pos_emb,
+        kv_caches=past_key_values,
+        use_cache=use_cache,
+        output_hidden_states=output_hidden_states,
+    )
+
+    if not return_dict:
+        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=presents,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attentions,
+    )
+
+
+class ChatGLMModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        self.original_chatglm_transformer_forward = model.transformer.forward
+
+    def __enter__(self):
+        super().__enter__()
+        self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.transformer.forward = self.original_chatglm_transformer_forward
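
The patcher above follows a save/patch/restore pattern: __init__ stashes the original transformer.forward, __enter__ rebinds _chatglm_transformer_forward onto the transformer instance with types.MethodType, and __exit__ puts the original back. A self-contained sketch of that pattern (the Toy* names are invented for illustration, not optimum code):

# Minimal, self-contained illustration of the patching pattern used by
# ChatGLMModelPatcher above.
import types


class ToyTransformer:
    def forward(self, x):
        return f"original({x})"


def _patched_forward(self, x):
    # Export-friendly replacement bound onto the instance while patching is active.
    return f"patched({x})"


class ToyPatcher:
    def __init__(self, model):
        self._model = model
        self._original_forward = model.forward  # stash the bound original

    def __enter__(self):
        # Bind the replacement as a method of this particular instance.
        self._model.forward = types.MethodType(_patched_forward, self._model)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always restore the original, even if the body raised.
        self._model.forward = self._original_forward


model = ToyTransformer()
with ToyPatcher(model):
    assert model.forward("x") == "patched(x)"
assert model.forward("x") == "original(x)"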
