
Commit 67ef640 (1 parent: 8273de7)

refactor: apply code style while preserving logic for olmo

2 files changed: +5 -9 lines

optimum/intel/openvino/modeling_decoder.py (+4 -8)
@@ -386,10 +386,8 @@ def prepare_inputs(
         inputs = {}
         if not self.stateful:
             if past_key_values is not None:
-                if (
-                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
-                    or self.config.model_type == "falcon"
-                    and self.config.new_decoder_architecture
+                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
+                    self.config.model_type == "falcon" and self.config.new_decoder_architecture
                 ):
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
@@ -499,10 +497,8 @@ def forward(
         if self.use_cache:
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
             past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if (
-                self.config.model_type not in MULTI_QUERY_ATTN_MODELS
-                or self.config.model_type == "falcon"
-                and self.config.new_decoder_architecture
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
             ):
                 # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                 past_key_values = tuple(
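
Why this preserves behavior: in Python, `and` binds more tightly than `or`, so the original flat condition already evaluated as `A or (B and C)`; the refactor only makes that grouping explicit. A minimal standalone sketch (the set contents below are illustrative stand-ins, not copied from the repo) checks that the two forms agree on every input:

from itertools import product

MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}  # illustrative stand-in

for model_type, new_arch in product(["falcon", "gpt_bigcode", "llama"], [True, False]):
    # old flat form: relies on `and` binding tighter than `or`
    old = (
        model_type not in MULTI_QUERY_ATTN_MODELS
        or model_type == "falcon"
        and new_arch
    )
    # new form: the same grouping, written explicitly
    new = model_type not in MULTI_QUERY_ATTN_MODELS or (
        model_type == "falcon" and new_arch
    )
    assert old == new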

tests/openvino/test_modeling.py (+1 -1)
@@ -876,7 +876,7 @@ def test_beam_search(self, model_arch):
         transformers_model.config.eos_token_id = None

         for gen_config in gen_configs:
-            if gen_config.do_sample and model_arch == "baichuan2-13b":
+            if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]:
                 continue
             transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
             ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
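
Context for the test change: with `do_sample=True`, generation draws tokens from the model's output distribution, so a bitwise match between the transformers and OpenVINO runs is only guaranteed if both backends produce numerically identical logits; tiny numeric drift can flip a sampled token even under the same seed, which is presumably why exact-output comparison is skipped for these architectures. A toy illustration (hypothetical numbers, not from the test suite):

# Two backends produce almost identical probabilities for two tokens.
probs_backend_a = [0.500001, 0.499999]
probs_backend_b = [0.499999, 0.500001]

u = 0.5  # the same uniform draw under a shared seed

# Inverse-CDF sampling: pick token 0 if the draw falls below its probability.
token_a = 0 if u < probs_backend_a[0] else 1  # -> 0
token_b = 0 if u < probs_backend_b[0] else 1  # -> 1

print(token_a, token_b)  # 0 1: a ~1e-6 difference flipped the sampled token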
