|
27 | 27 | from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
28 | 28 | from transformers.modeling_outputs import CausalLMOutputWithPast
|
29 | 29 |
|
30 |
| -from optimum.utils import NormalizedConfigManager |
| 30 | +from optimum.utils.normalized_config import NormalizedConfigManager |
31 | 31 |
|
32 | 32 | from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
|
33 | 33 | from ...exporters.openvino.stateful import model_has_state
|
@@ -132,7 +132,6 @@ def __init__(
|
132 | 132 | self.stateful = model_has_sinks
|
133 | 133 | self.main_input_name = "input_ids"
|
134 | 134 | self.num_pkv = 2
|
135 |
| - self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) |
136 | 135 | self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
|
137 | 136 | self.key_value_output_names = [key for key in self.output_names if "present" in key]
|
138 | 137 | self._original_model = self.model.clone() # keep original model for serialization
|
@@ -321,6 +320,13 @@ def reshape(self, batch_size: int, sequence_length: int):
|
321 | 320 | logger.warning("Static shapes are not supported for causal language model.")
|
322 | 321 | return self
|
323 | 322 |
|
| 323 | + @property |
| 324 | + def normalized_config(self): |
| 325 | + logger.warning( |
| 326 | + "access to normalized_config attribute is deprecated and will be removed in future versions, please use config" |
| 327 | + ) |
| 328 | + return NormalizedConfigManager.get_normalized_config_class(self.config.model_type)(self.config) |
| 329 | + |
324 | 330 | def compile(self):
|
325 | 331 | if self.request is None:
|
326 | 332 | super().compile()
|
@@ -364,7 +370,7 @@ def forward(
|
364 | 370 |
|
365 | 371 | batch_size = input_ids.shape[0]
|
366 | 372 | if self.config.model_type == "bloom":
|
367 |
| - batch_size *= self.normalized_config.num_attention_heads |
| 373 | + batch_size *= self.config.num_attention_heads |
368 | 374 |
|
369 | 375 | inputs = {}
|
370 | 376 | past_len = 0
|
@@ -592,8 +598,8 @@ def _reorder_cache(
|
592 | 598 | if self.stateful:
|
593 | 599 | beam_idx = np.array(beam_idx)
|
594 | 600 | batch_size = beam_idx.shape[0]
|
595 |
| - indices = np.array(range(batch_size * self.normalized_config.num_attention_heads)) |
596 |
| - indices = indices.reshape([batch_size, self.normalized_config.num_attention_heads]) |
| 601 | + indices = np.array(range(batch_size * self.config.num_attention_heads)) |
| 602 | + indices = indices.reshape([batch_size, self.config.num_attention_heads]) |
597 | 603 | self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
|
598 | 604 | return past_key_values
|
599 | 605 | else:
|
|
0 commit comments