relax requirements to have registered normalized config for usage con… #537

Merged
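In short: the export and modeling code no longer requires the model type to have a normalized config registered in NormalizedConfigManager; the attributes previously read through a normalized config (e.g. num_attention_heads for bloom) are now read straight from the model's PretrainedConfig. A minimal sketch of the access pattern before and after this change, taken from the stateful-export path below:

    # before: only works when config.model_type is registered in NormalizedConfigManager
    normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
    num_attention_heads = normalized_config.num_attention_heads if config.model_type == "bloom" else 1

    # after: plain attribute access on the PretrainedConfig, no registration required
    num_attention_heads = config.num_attention_heads if config.model_type == "bloom" else 1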
optimum/exporters/openvino/stateful.py (5 changes: 1 addition & 4 deletions)
@@ -22,7 +22,6 @@
 from openvino.runtime import opset13
 from optimum.exporters import TasksManager
 from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version
-from optimum.utils.normalized_config import NormalizedConfigManager


 def model_has_state(ov_model: ov.Model):
@@ -217,9 +216,7 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
     batch_dim = 1 if config.model_type == "chatglm" else 0

     fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
-
-    normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
-    num_attention_heads = normalized_config.num_attention_heads if config.model_type == "bloom" else 1
+    num_attention_heads = config.num_attention_heads if config.model_type == "bloom" else 1
     make_stateful(
         ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
     )
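The practical effect: patch_stateful no longer fails for decoder architectures whose model type has no normalized config registered, as long as their config exposes the needed attribute directly. A hedged sketch of the old failure mode (the custom model type below is illustrative; in current optimum releases an unregistered model type makes the lookup raise a KeyError):

    from types import SimpleNamespace

    from optimum.utils.normalized_config import NormalizedConfigManager

    # hypothetical config for an architecture that is not registered in NormalizedConfigManager
    config = SimpleNamespace(model_type="my_custom_decoder", num_attention_heads=16)

    try:
        # old path: requires a registered normalized config class
        NormalizedConfigManager.get_normalized_config_class(config.model_type)
    except KeyError as e:
        print(f"lookup failed for {config.model_type!r}: {e}")

    # new path: works for any config object that carries the attribute
    num_attention_heads = config.num_attention_heads if config.model_type == "bloom" else 1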
optimum/intel/openvino/modeling_decoder.py (16 changes: 11 additions & 5 deletions)
@@ -27,7 +27,7 @@
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import CausalLMOutputWithPast

-from optimum.utils import NormalizedConfigManager
+from optimum.utils.normalized_config import NormalizedConfigManager

 from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
 from ...exporters.openvino.stateful import model_has_state
@@ -132,7 +132,6 @@ def __init__(
         self.stateful = model_has_sinks
         self.main_input_name = "input_ids"
         self.num_pkv = 2
-        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
         self.key_value_output_names = [key for key in self.output_names if "present" in key]
         self._original_model = self.model.clone()  # keep original model for serialization
@@ -321,6 +320,13 @@ def reshape(self, batch_size: int, sequence_length: int):
         logger.warning("Static shapes are not supported for causal language model.")
         return self

+    @property
+    def normalized_config(self):
+        logger.warning(
+            "access to normalized_config attribute is deprecated and will be removed in future versions, please use config"
+        )
+        return NormalizedConfigManager.get_normalized_config_class(self.config.model_type)(self.config)
+
     def compile(self):
         if self.request is None:
             super().compile()
@@ -364,7 +370,7 @@ def forward(

         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
-            batch_size *= self.normalized_config.num_attention_heads
+            batch_size *= self.config.num_attention_heads

         inputs = {}
         past_len = 0
@@ -592,8 +598,8 @@ def _reorder_cache(
         if self.stateful:
             beam_idx = np.array(beam_idx)
             batch_size = beam_idx.shape[0]
-            indices = np.array(range(batch_size * self.normalized_config.num_attention_heads))
-            indices = indices.reshape([batch_size, self.normalized_config.num_attention_heads])
+            indices = np.array(range(batch_size * self.config.num_attention_heads))
+            indices = indices.reshape([batch_size, self.config.num_attention_heads])
             self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
             return past_key_values
         else:
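For downstream code that still reads model.normalized_config, the new property keeps the attribute working (it is rebuilt on access from self.config) while logging a deprecation warning. A toy, self-contained sketch of the same pattern; the class name and warning text here are illustrative, not part of the PR:

    import logging

    from optimum.utils.normalized_config import NormalizedConfigManager

    logger = logging.getLogger(__name__)


    class DeprecatedAttributeDemo:
        # illustrative stand-in for OVModelForCausalLM: the plain config is the single
        # source of truth, and the legacy attribute is recomputed lazily on access
        def __init__(self, config):
            self.config = config

        @property
        def normalized_config(self):
            logger.warning("normalized_config is deprecated, read attributes from .config instead")
            return NormalizedConfigManager.get_normalized_config_class(self.config.model_type)(self.config)

Reading demo.normalized_config still returns a normalized config instance for registered model types, so existing callers keep working during the deprecation window.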