File tree: 2 files changed, +5 -9 lines changed
Original file line number | Diff line number | Diff line change
@@ -386,10 +386,8 @@ def prepare_inputs(
386
386
inputs = {}
387
387
if not self .stateful :
388
388
if past_key_values is not None :
389
- if (
390
- self .config .model_type not in MULTI_QUERY_ATTN_MODELS
391
- or self .config .model_type == "falcon"
392
- and self .config .new_decoder_architecture
389
+ if self .config .model_type not in MULTI_QUERY_ATTN_MODELS or (
390
+ self .config .model_type == "falcon" and self .config .new_decoder_architecture
393
391
):
394
392
if self ._pkv_precision == Type .bf16 :
395
393
# numpy does not support bf16, pretending f16, should change to bf16
@@ -499,10 +497,8 @@ def forward(
499
497
if self .use_cache :
500
498
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
501
499
past_key_values = tuple (self .request .get_tensor (key ).data for key in self .key_value_output_names )
502
- if (
503
- self .config .model_type not in MULTI_QUERY_ATTN_MODELS
504
- or self .config .model_type == "falcon"
505
- and self .config .new_decoder_architecture
500
+ if self .config .model_type not in MULTI_QUERY_ATTN_MODELS or (
501
+ self .config .model_type == "falcon" and self .config .new_decoder_architecture
506
502
):
507
503
# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
508
504
past_key_values = tuple (
Original file line number | Diff line number | Diff line change
@@ -876,7 +876,7 @@ def test_beam_search(self, model_arch):
876
876
transformers_model .config .eos_token_id = None
877
877
878
878
for gen_config in gen_configs :
879
- if gen_config .do_sample and model_arch == "baichuan2-13b" :
879
+ if gen_config .do_sample and model_arch in [ "baichuan2-13b" , "olmo" ] :
880
880
continue
881
881
transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
882
882
ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
You can’t perform that action at this time.
0 commit comments