Commit df8a5c6

add test

Parent: f263f3f

2 files changed: +95 -14 lines


optimum/intel/openvino/modeling_decoder.py (+14, -14)
@@ -1418,22 +1418,22 @@ def _group_beam_search(
             beam_next_tokens = beam_outputs["next_beam_tokens"]
             beam_idx = beam_outputs["next_beam_indices"]
 
-        if return_dict_in_generate and output_scores:
-            beam_indices[beam_group_idx] = tuple(
-                beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
-            )
+            if return_dict_in_generate and output_scores:
+                beam_indices[beam_group_idx] = tuple(
+                    beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
+                )
 
-        input_ids[batch_group_indices] = group_input_ids[beam_idx]
-        group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-        current_tokens[batch_group_indices] = group_input_ids[:, -1]
+            input_ids[batch_group_indices] = group_input_ids[beam_idx]
+            group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+            current_tokens[batch_group_indices] = group_input_ids[:, -1]
 
-        # (beam_idx // group_size) -> batch_idx
-        # (beam_idx % group_size) -> offset of idx inside the group
-        reordering_indices[batch_group_indices] = (
-            num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
-            + group_start_idx
-            + (beam_idx % group_size)
-        )
+            # (beam_idx // group_size) -> batch_idx
+            # (beam_idx % group_size) -> offset of idx inside the group
+            reordering_indices[batch_group_indices] = (
+                num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
+                + group_start_idx
+                + (beam_idx % group_size)
+            )
 
         # Store scores, attentions and hidden_states when required
         if return_dict_in_generate:
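
For intuition about the re-indented block: as its comment says, `beam_idx // group_size` recovers the batch item and `beam_idx % group_size` the beam's offset inside its group, so the expression translates the scorer's group-local beam picks into row indices of the full `num_beams`-wide cache layout (which appears to be how a single per-step cache reorder is accumulated across all beam groups for the stateful model). A minimal sketch with made-up numbers, not values from the commit:

    import torch

    # Illustrative setup: num_beams=4 beams per batch item, split into
    # 2 groups of group_size=2; suppose this is the second group, so it
    # starts at offset group_start_idx=2 inside each item's beam block.
    num_beams, group_size, group_start_idx = 4, 2, 2

    # Group-local picks for a batch of 2 items: values in [0, 2 * group_size).
    beam_idx = torch.tensor([1, 0, 3, 2])

    reordering_indices = (
        num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")  # item index * num_beams
        + group_start_idx                                                   # skip earlier groups' beams
        + (beam_idx % group_size)                                           # offset inside this group
    )
    print(reordering_indices)  # tensor([3, 2, 7, 6]) -> rows in the full beam-ordered cache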

tests/openvino/test_modeling.py (+81)
@@ -778,6 +778,87 @@ def test_default_filling_attention_mask_and_position_ids(self):
         del model_with_cache
         gc.collect()
 
+    def test_beam_search(self):
+        model_id = MODEL_NAMES["llama"]
+        ov_model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
+        ov_model_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        ov_model_stateful.generation_config.eos_token_id = None
+        ov_model_stateless.generation_config.eos_token_id = None
+        transformers_model.generation_config.eos_token_id = None
+        ov_model_stateful.config.eos_token_id = None
+        ov_model_stateless.config.eos_token_id = None
+        transformers_model.config.eos_token_id = None
+
+        # beam search
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+        # beam sample
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=True,
+            eos_token_id=None,
+            top_k=1,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
+        # group beam search
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+            num_beam_groups=2,
+            diversity_penalty=0.0000001,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
+        # constrained beam search
+        force_word = "cat"
+        force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids]
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+            force_words_ids=force_words_ids,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
 
 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
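
One detail worth calling out in the constrained-beam-search block of the new test: wrapping `tokenizer([force_word], add_special_tokens=False).input_ids` in an extra list produces the triply nested `List[List[List[int]]]` form of `force_words_ids`, which transformers treats as a disjunctive constraint: at least one token sequence from each group must appear in every finished beam. A minimal sketch of the shape (the token id is made up; the real one depends on the tokenizer):

    # tokenizer(["cat"], add_special_tokens=False).input_ids returns one id list
    # per input string, e.g. [[7163]] (illustrative id, tokenizer-dependent).
    word_ids = [[7163]]           # List[List[int]]: tokenization(s) of "cat"
    force_words_ids = [word_ids]  # [[[7163]]]: one disjunctive constraint group
    # generate(..., force_words_ids=force_words_ids) then requires one of the
    # sequences in the group to occur in each returned hypothesis.

To exercise only this test locally, `pytest tests/openvino/test_modeling.py -k test_beam_search` should suffice, assuming a development install of optimum-intel with its OpenVINO extras.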
