Skip to content

Commit b21d14d

Browse files
authored
Fix bloom generation (#736)
* Fix bloom generation * remove unused variable * add style * add error message * update model id
1 parent 5dfbcbc commit b21d14d

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

optimum/intel/openvino/modeling_decoder.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -587,11 +587,11 @@ def _deduplicate_inputs(self, model_inputs: Dict):
587587
)
588588
for input_name, input_tensor in model_inputs.items():
589589
if input_name not in ["input_ids", "beam_idx"]:
590-
if not isinstance(input_tensor, Tensor):
590+
if input_name not in self.key_value_input_names:
591591
upd_model_inputs[input_name] = input_tensor[indicies]
592592
else:
593-
shape = input_tensor.shape
594-
dtype = input_tensor.element_type
593+
shape = input_tensor.shape if isinstance(input_tensor, Tensor) else list(input_tensor.shape)
594+
dtype = input_tensor.element_type if isinstance(input_tensor, Tensor) else Type(input_tensor.dtype)
595595
upd_batch_size = indicies.shape[0]
596596
if self.config.model_type == "bloom":
597597
upd_batch_size *= self.config.num_attention_heads

tests/openvino/test_modeling.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -879,14 +879,14 @@ def test_beam_search(self, model_arch):
879879
ov_model_stateless.config.eos_token_id = None
880880
transformers_model.config.eos_token_id = None
881881

882-
for gen_config in gen_configs:
882+
for idx, gen_config in enumerate(gen_configs):
883883
if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]:
884884
continue
885885
transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
886886
ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
887-
self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
887+
self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs), f"generation config : {idx}")
888888
ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
889-
self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
889+
self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs), f"generation config : {idx}")
890890

891891

892892
class OVModelForMaskedLMIntegrationTest(unittest.TestCase):

tests/openvino/utils_tests.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
"internlm2": "katuni4ka/tiny-random-internlm2",
6666
"levit": "hf-internal-testing/tiny-random-LevitModel",
6767
"longt5": "hf-internal-testing/tiny-random-longt5",
68-
"llama": "fxmarty/tiny-llama-fast-tokenizer",
68+
"llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
6969
"llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
7070
"llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ",
7171
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",

0 commit comments

Comments (0)