
Commit 1ee4535: add test

1 parent 878921b

2 files changed, +95 -14 lines


optimum/intel/openvino/modeling_decoder.py (+14 -14)

Whitespace-only change: each removed line is re-added with identical content, so only the indentation of this block inside _group_beam_search changes (the exact indentation is not preserved in this page capture).

@@ -1407,22 +1407,22 @@ def _group_beam_search(
         beam_next_tokens = beam_outputs["next_beam_tokens"]
         beam_idx = beam_outputs["next_beam_indices"]
 
-        if return_dict_in_generate and output_scores:
-            beam_indices[beam_group_idx] = tuple(
-                beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
-            )
+        if return_dict_in_generate and output_scores:
+            beam_indices[beam_group_idx] = tuple(
+                beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
+            )
 
-        input_ids[batch_group_indices] = group_input_ids[beam_idx]
-        group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-        current_tokens[batch_group_indices] = group_input_ids[:, -1]
+        input_ids[batch_group_indices] = group_input_ids[beam_idx]
+        group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+        current_tokens[batch_group_indices] = group_input_ids[:, -1]
 
-        # (beam_idx // group_size) -> batch_idx
-        # (beam_idx % group_size) -> offset of idx inside the group
-        reordering_indices[batch_group_indices] = (
-            num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
-            + group_start_idx
-            + (beam_idx % group_size)
-        )
+        # (beam_idx // group_size) -> batch_idx
+        # (beam_idx % group_size) -> offset of idx inside the group
+        reordering_indices[batch_group_indices] = (
+            num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
+            + group_start_idx
+            + (beam_idx % group_size)
+        )
 
         # Store scores, attentions and hidden_states when required
         if return_dict_in_generate:
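
For context, a minimal sketch of the index arithmetic behind reordering_indices (the sizes below are hypothetical, chosen only to make the mapping concrete): beam_idx is flattened over (batch x group_size), so integer division recovers the batch index, the remainder gives the offset inside the group, and group_start_idx shifts into the right group of the full (batch x num_beams) cache layout.

import torch

# Hypothetical sizes: 4 beams split into 2 groups of 2, batch of 2,
# looking at the second group (group_start_idx = 2).
num_beams, group_size, group_start_idx = 4, 2, 2

# Example scorer output: sample 0 keeps group-beams [1, 0], sample 1 keeps [2, 3].
beam_idx = torch.tensor([1, 0, 2, 3])

batch_idx = torch.div(beam_idx, group_size, rounding_mode="floor")  # tensor([0, 0, 1, 1])
offset_in_group = beam_idx % group_size                             # tensor([1, 0, 0, 1])

# Row each selected beam occupies in the full (batch * num_beams) cache layout.
reordering = num_beams * batch_idx + group_start_idx + offset_in_group
print(reordering)  # tensor([3, 2, 6, 7])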

tests/openvino/test_modeling.py (+81)

@@ -757,6 +757,87 @@ def test_default_filling_attention_mask_and_position_ids(self):
         del model_with_cache
         gc.collect()
 
+    def test_beam_search(self):
+        model_id = MODEL_NAMES["llama"]
+        ov_model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
+        ov_model_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        ov_model_stateful.generation_config.eos_token_id = None
+        ov_model_stateless.generation_config.eos_token_id = None
+        transformers_model.generation_config.eos_token_id = None
+        ov_model_stateful.config.eos_token_id = None
+        ov_model_stateless.config.eos_token_id = None
+        transformers_model.config.eos_token_id = None
+
+        # beam search
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+        # beam sample
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=True,
+            eos_token_id=None,
+            top_k=1,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
+        # group beam search
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+            num_beam_groups=2,
+            diversity_penalty=0.0000001,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
+        # constrained beam search
+        force_word = "cat"
+        force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids]
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+            force_words_ids=force_words_ids,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
 
 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
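
Note that the test disables eos_token_id everywhere and pins min_new_tokens == max_new_tokens, so all three models emit exactly ten tokens and the outputs can be compared elementwise with torch.allclose. To run only the new test from a repository checkout (assuming the project's usual pytest workflow; this command is not part of the commit):

python -m pytest tests/openvino/test_modeling.py -k test_beam_search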
