12 | 12 | # See the License for the specific language governing permissions and
13 | 13 | # limitations under the License.
14 | 14 |
15 |    | -from typing import TYPE_CHECKING, Any, Dict, Optional, Union
   | 15 | +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
16 | 16 |
17 | 17 | from packaging import version
18 | 18 | from transformers.utils import is_tf_available
19 | 19 |
20 | 20 | from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
21 | 21 | from optimum.exporters.tasks import TasksManager
22 |    | -from optimum.utils.input_generators import DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator
   | 22 | +from optimum.utils import DEFAULT_DUMMY_SHAPES
   | 23 | +from optimum.utils.input_generators import (
   | 24 | +    DummyInputGenerator,
   | 25 | +    DummyPastKeyValuesGenerator,
   | 26 | +    DummyTextInputGenerator,
   | 27 | +    MistralDummyPastKeyValuesGenerator,
   | 28 | +)
23 | 29 | from optimum.utils.normalized_config import NormalizedTextConfig
24 | 30 |
25 |    | -from .model_patcher import MixtralModelPatcher
   | 31 | +from .model_patcher import ChatGLMModelPatcher, MixtralModelPatcher
26 | 32 |
27 | 33 |
28 | 34 | if TYPE_CHECKING:

@@ -70,6 +76,161 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):

70 | 76 |     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
71 | 77 |
72 | 78 |
   |  79 | +class ChatGLM2DummyTextInputGenerator(DummyTextInputGenerator):
   |  80 | +    SUPPORTED_INPUT_NAMES = {
   |  81 | +        "input_ids",
   |  82 | +        "attention_mask",
   |  83 | +        "token_type_ids",
   |  84 | +        "position_ids",
   |  85 | +    }
   |  86 | +
   |  87 | +    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
   |  88 | +        import torch
   |  89 | +
   |  90 | +        input = super().generate(input_name, framework, int_dtype, float_dtype)
   |  91 | +        if input_name == "attention_mask":
   |  92 | +            input = torch.ones(input.shape, dtype=input.dtype)  # fully unmasked attention for tracing
   |  93 | +        if input_name == "position_ids":
   |  94 | +            bs = input.shape[0]
   |  95 | +            input = torch.arange(input.shape[1], dtype=input.dtype).repeat(bs, 1)  # positions 0..seq_len-1 per batch row
   |  96 | +        return input
   |  97 | +
   |  98 | +
   |  99 | +class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
   | 100 | +    def __init__(
   | 101 | +        self,
   | 102 | +        task: str,
   | 103 | +        normalized_config: NormalizedTextConfig,
   | 104 | +        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
   | 105 | +        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
   | 106 | +        random_batch_size_range: Optional[Tuple[int, int]] = None,
   | 107 | +        random_sequence_length_range: Optional[Tuple[int, int]] = None,
   | 108 | +        **kwargs,
   | 109 | +    ):
   | 110 | +        super().__init__(
   | 111 | +            task=task,
   | 112 | +            normalized_config=normalized_config,
   | 113 | +            batch_size=batch_size,
   | 114 | +            sequence_length=sequence_length,
   | 115 | +            random_batch_size_range=random_batch_size_range,
   | 116 | +            random_sequence_length_range=random_sequence_length_range,
   | 117 | +        )
   | 118 | +        self.multi_query_group_num = normalized_config.multi_query_group_num
   | 119 | +        self.head_dim = self.hidden_size // self.num_attention_heads
   | 120 | +
   | 121 | +    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
   | 122 | +        past_key_shape = (  # ChatGLM2's KV cache is sequence-first: [seq_len, batch, kv_groups, head_dim]
   | 123 | +            self.sequence_length,
   | 124 | +            self.batch_size,
   | 125 | +            self.multi_query_group_num,
   | 126 | +            self.head_dim,
   | 127 | +        )
   | 128 | +        past_value_shape = (
   | 129 | +            self.sequence_length,
   | 130 | +            self.batch_size,
   | 131 | +            self.multi_query_group_num,
   | 132 | +            self.head_dim,
   | 133 | +        )
   | 134 | +        return [
   | 135 | +            (
   | 136 | +                self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
   | 137 | +                self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
   | 138 | +            )
   | 139 | +            for _ in range(self.num_layers)
   | 140 | +        ]
   | 141 | +
   | 142 | +
   | 143 | +@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
   | 144 | +class ChatGLM2OpenVINOConfig(TextDecoderOnnxConfig):
   | 145 | +    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers")
   | 146 | +    DUMMY_INPUT_GENERATOR_CLASSES = (ChatGLM2DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
   | 147 | +    DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator
   | 148 | +    no_position_ids = False
   | 149 | +
   | 150 | +    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
   | 151 | +        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
   | 152 | +
   | 153 | +        dummy_inputs = {}
   | 154 | +        input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
   | 155 | +        if self.use_past_in_inputs and self.use_cache_branch is not False:
   | 156 | +            input_names.append("past_key_values")
   | 157 | +
   | 158 | +        for input_name in input_names:
   | 159 | +            input_was_inserted = False
   | 160 | +            for dummy_input_gen in dummy_inputs_generators:
   | 161 | +                if dummy_input_gen.supports_input(input_name):
   | 162 | +                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
   | 163 | +                        dummy_input_gen,
   | 164 | +                        input_name,
   | 165 | +                        framework,
   | 166 | +                        input_shapes=kwargs,
   | 167 | +                    )
   | 168 | +                    input_was_inserted = True
   | 169 | +                    break
   | 170 | +            if not input_was_inserted:
   | 171 | +                raise RuntimeError(
   | 172 | +                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
   | 173 | +                )
   | 174 | +
   | 175 | +        # refer to https://github.com/huggingface/optimum/pull/764
   | 176 | +        cond1 = self.use_past_in_inputs
   | 177 | +        cond2 = self.PAD_ATTENTION_MASK_TO_PAST
   | 178 | +        cond3 = self.use_cache_branch is not False
   | 179 | +        cond4 = "attention_mask" in dummy_inputs
   | 180 | +        if cond1 and cond2 and cond3 and cond4:
   | 181 | +            # The cache is sequence-first for ChatGLM2, so the past length is dimension 0 of the value tensor.
   | 182 | +            past_length = dummy_inputs["past_key_values"][0][1].shape[0]
   | 183 | +            for k, v in dummy_inputs.items():
   | 184 | +                if k not in ["attention_mask", "past_key_values"]:
   | 185 | +                    dummy_inputs[k] = v[:, -1:]  # with a cache present, only the newest token is fed
   | 186 | +
   | 187 | +            dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
   | 188 | +                dummy_inputs["attention_mask"],
   | 189 | +                desired_length=past_length + 1,
   | 190 | +                dim=1,
   | 191 | +                dtype=dummy_inputs["attention_mask"].dtype,
   | 192 | +            )
   | 193 | +
   | 194 | +        return dummy_inputs
   | 195 | +
   | 196 | +    @property
   | 197 | +    def inputs(self) -> Dict[str, Dict[int, str]]:
   | 198 | +        common_inputs = super().inputs
   | 199 | +        if not self.no_position_ids and self.task == "text-generation":
   | 200 | +            common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
   | 201 | +
   | 202 | +        return common_inputs
   | 203 | +
   | 204 | +    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
   | 205 | +        """
   | 206 | +        Fills the `inputs_or_outputs` mapping with past_key_values dynamic axes considering the direction.
   | 207 | +
   | 208 | +        Args:
   | 209 | +            inputs_or_outputs (`Dict[str, Dict[int, str]]`): The mapping to fill.
   | 210 | +            direction (`str`):
   | 211 | +                Either "inputs" or "outputs"; specifies whether `inputs_or_outputs` is the input mapping or the
   | 212 | +                output mapping, which matters for axes naming.
   | 213 | +        """
   | 214 | +        if direction not in ["inputs", "outputs"]:
   | 215 | +            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
   | 216 | +
   | 217 | +        if direction == "inputs":
   | 218 | +            decoder_sequence_name = "past_sequence_length"
   | 219 | +            name = "past_key_values"
   | 220 | +        else:
   | 221 | +            decoder_sequence_name = "past_sequence_length + 1"
   | 222 | +            name = "present"
   | 223 | +
   | 224 | +        for i in range(self._normalized_config.num_layers):  # axis 0 is the sequence axis in ChatGLM2's cache layout
   | 225 | +            inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name}
   | 226 | +            inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name}
   | 227 | +
   | 228 | +    def patch_model_for_export(
   | 229 | +        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
   | 230 | +    ) -> "ModelPatcher":
   | 231 | +        return ChatGLMModelPatcher(self, model, model_kwargs=model_kwargs)
   | 232 | +
   | 233 | +
73 | 234 | @register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"])
74 | 235 | class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
75 | 236 |     # This is because of the patching of torch.triu in AttentionMaskConverter, which exists from transformers>=4.35
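
For context, a minimal usage sketch of what the diff enables: once `"chatglm"` is registered with the `TasksManager`, optimum-intel's OpenVINO export resolves that architecture to `ChatGLM2OpenVINOConfig` automatically. The checkpoint id below is illustrative; any ChatGLM2-style checkpoint with `trust_remote_code` support would do.

```python
# Sketch: exporting a ChatGLM2 checkpoint to OpenVINO IR via optimum-intel.
# export=True triggers the conversion, exercising the dummy input and
# past-key-values generators defined in this file. Model id is illustrative.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "THUDM/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```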