rework inputs preparation for OVModelForCausalLM #620

Merged · 5 commits · Apr 2, 2024
96 changes: 40 additions & 56 deletions optimum/intel/openvino/modeling_decoder.py
@@ -120,6 +120,7 @@ def __init__(
self._original_model = self.model.clone() # keep original model for serialization
self._pkv_precision = Type.f32
self.next_beam_idx = None
self._past_length = 0
self.update_pkv_precision()
if self.is_dynamic:
self.model = self._reshape(self.model, -1, -1)
@@ -356,19 +357,14 @@ def prepare_inputs(
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> Dict:
if self.use_cache and past_key_values is not None:
input_ids = input_ids[:, -1:]

batch_size = input_ids.shape[0]
if self.config.model_type == "bloom":
batch_size *= self.config.num_attention_heads

inputs = {}
past_len = 0
if not self.stateful:
if past_key_values is not None:
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
past_len = past_key_values[0][1].shape[-2]
if self._pkv_precision == Type.bf16:
# numpy does not support bf16, pretending f16, should change to bf16
past_key_values = tuple(
@@ -381,8 +377,6 @@ def prepare_inputs(
past_key_values = tuple(
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)
else:
past_len = past_key_values[0].shape[-2]

# Add the past_key_values to the decoder inputs
inputs = dict(zip(self.key_value_input_names, past_key_values))
@@ -411,6 +405,8 @@ def prepare_inputs(
# Set initial value for the next beam_idx input that will be used at the current iteration
# and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
self.next_beam_idx = np.arange(batch_size, dtype=int)
self._past_length = 0
past_len = self._get_past_length(past_key_values)

inputs["input_ids"] = np.array(input_ids)
# Add the attention_mask inputs when needed
@@ -432,7 +428,7 @@ def prepare_inputs(
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
position_ids = position_ids[:, -input_ids.shape[1] :]

inputs["position_ids"] = position_ids

@@ -470,6 +466,7 @@ def forward(
# the first condition at the function beginning above.
# It should be something that is not None and it should be True when converted to Boolean.
past_key_values = ((),)
self._past_length += input_ids.shape[1]

if not self.stateful:
if self.use_cache:
@@ -485,19 +482,32 @@ def forward(

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
# Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

if past_key_values is not None:
past_len = self._get_past_length(past_key_values)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_len < input_ids.shape[1]:
input_ids = input_ids[:, past_len:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
position_ids = position_ids[:, -input_ids.shape[1] :]

return {
"input_ids": input_ids,
@@ -507,6 +517,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
"attention_mask": attention_mask,
}

def _get_past_length(self, past_key_values=None):
if past_key_values is None:
return 0
if self.stateful:
return self._past_length
if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
return past_key_values[0].shape[-2]
seq_length_dim = -2
if self.config.model_type == "chatglm":
seq_length_dim = 0
elif self.config.model_type == "qwen":
seq_length_dim = 1
# the input is a tuple of per-layer (key, value) pairs
if isinstance(past_key_values[0], (tuple, list)):
return past_key_values[0][1].shape[seq_length_dim]
# past key values come flattened (one tensor per key/value, no per-layer nesting)
return past_key_values[1].shape[seq_length_dim]

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
@@ -573,10 +601,6 @@ def _from_pretrained(
model_type = config.model_type.replace("_", "-")
if model_type == "bloom":
init_cls = OVBloomForCausalLM
elif model_type == "mpt":
init_cls = OVMPTForCausalLM
elif model_type == "opt":
init_cls = OVOPTForCausalLM
elif model_type == "gpt-bigcode":
init_cls = OVGPTBigCodeForCausalLM
else:
@@ -630,22 +654,12 @@ def _from_pretrained(
class OVBloomForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

# only last token for input_ids if past is not None
if past_key_values and not self.stateful:
# the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": None,
"attention_mask": attention_mask,
}
return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)

# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
def _reorder_cache(
@@ -712,36 +726,6 @@ def _convert_to_standard_cache(
)


class OVOPTForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": None,
"attention_mask": attention_mask,
}


class OVMPTForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": None,
"attention_mask": attention_mask,
}


class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
def _reorder_cache(
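
For context, a minimal standalone sketch of the three-case logic the reworked prepare_inputs_for_generation uses to keep only the tokens the model has not yet processed. This is not code from the PR; the helper name and signature are assumptions for readability.

import torch

def trim_unprocessed_tokens(
    input_ids: torch.LongTensor,
    attention_mask: torch.LongTensor,
    past_len: int,
) -> torch.LongTensor:
    # Case 1: the attention mask is longer than input_ids, so part of the prompt
    # lives only in the cache (e.g. it was originally passed as inputs_embeds).
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        return input_ids[:, -(attention_mask.shape[1] - past_len) :]
    # Case 2: input_ids still holds the full sequence; drop the already-cached prefix.
    if past_len < input_ids.shape[1]:
        return input_ids[:, past_len:]
    # Case 3: past_len >= input_ids length; assume only unprocessed tokens remain.
    return input_ids
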
5 changes: 5 additions & 0 deletions tests/openvino/test_modeling.py
@@ -632,6 +632,11 @@ def test_multiple_inputs(self, model_arch):
outputs = model.generate(**tokens, generation_config=generation_config)
self.assertIsInstance(outputs, torch.Tensor)
self.assertEqual(outputs.shape[0], 3)
# test that generation result is reproducible
outputs2 = model.generate(**tokens, generation_config=generation_config)
self.assertIsInstance(outputs2, torch.Tensor)
self.assertEqual(outputs2.shape[0], 3)
self.assertTrue(torch.allclose(outputs2, outputs))
del model
gc.collect()
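
As a usage sketch of the reproducibility check added above, generating twice with greedy decoding should produce identical outputs. The model id, prompts, and generation settings below are placeholders, not values from the test suite.

import torch
from transformers import AutoTokenizer, GenerationConfig
from optimum.intel import OVModelForCausalLM

model_id = "gpt2"  # placeholder; any decoder-only checkpoint supported by optimum-intel
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no padding token by default
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

tokens = tokenizer(["Hello", "Hi there", "Good morning"], return_tensors="pt", padding=True)
generation_config = GenerationConfig(max_new_tokens=10, do_sample=False)

first = model.generate(**tokens, generation_config=generation_config)
second = model.generate(**tokens, generation_config=generation_config)
assert torch.equal(first, second)  # greedy decoding should be deterministic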