Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix beam search test reported issues #718

Merged
merged 4 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,8 @@ def prepare_inputs(
inputs = {}
if not self.stateful:
if past_key_values is not None:
if (
self.config.model_type not in MULTI_QUERY_ATTN_MODELS
or self.config.model_type == "falcon"
and self.config.new_decoder_architecture
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
if self._pkv_precision == Type.bf16:
# numpy does not support bf16, pretending f16, should change to bf16
Expand Down Expand Up @@ -499,10 +497,8 @@ def forward(
if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
if (
self.config.model_type not in MULTI_QUERY_ATTN_MODELS
or self.config.model_type == "falcon"
and self.config.new_decoder_architecture
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
past_key_values = tuple(
Expand Down Expand Up @@ -559,10 +555,8 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
if indicies.shape[0] != 1:
logits = logits[indicies]
if past_key_values and not self.stateful:
if (
self.config.model_type not in MULTI_QUERY_ATTN_MODELS
or self.config.model_type == "falcon"
and self.config.new_decoder_architecture
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
past_key_values = tuple(
tuple(
Expand All @@ -581,7 +575,7 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
if self.next_beam_idx is not None
else np.arange(batch_size, dtype=int)[indicies]
)
self._second_iter_beam_search = True
self._second_iter_beam_search = True
return logits, past_key_values

def _deduplicate_inputs(self, model_inputs: Dict):
Expand Down Expand Up @@ -692,7 +686,7 @@ def _reorder_cache(
self._second_iter_beam_search = False
return past_key_values
else:
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
return tuple(
Expand Down
10 changes: 8 additions & 2 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,8 +797,8 @@ def test_default_filling_attention_mask_and_position_ids(self):
gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@pytest.mark.run_slow
@slow
# @pytest.mark.run_slow
# @slow
def test_beam_search(self, model_arch):
model_kwargs = {}
model_id = MODEL_NAMES[model_arch]
Expand All @@ -812,6 +812,10 @@ def test_beam_search(self, model_arch):
return

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
if model_arch == "persimmon":
tokenizer.pad_token_id = tokenizer.bos_token_id
tokenizer.eos_token_id = tokenizer.bos_token_id

beam_search_gen_config = GenerationConfig(
max_new_tokens=10,
min_new_tokens=10,
Expand Down Expand Up @@ -872,6 +876,8 @@ def test_beam_search(self, model_arch):
transformers_model.config.eos_token_id = None

for gen_config in gen_configs:
if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]:
continue
transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
Expand Down
Loading