# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

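# Test helpers for comparing text generation between Hugging Face / Optimum-Intel models
# and the ContinuousBatchingPipeline from py_continuous_batching.
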
import os
import pytest

from optimum.intel import OVModelForCausalLM
from pathlib import Path
from py_continuous_batching import ContinuousBatchingPipeline, GenerationConfig, SchedulerConfig, GenerationResult
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GenerationConfig as HFGenerationConfig
from typing import List, Tuple


def get_greedy() -> GenerationConfig:
    generation_config = GenerationConfig()
    generation_config.num_return_sequences = 1
    return generation_config

def get_beam_search() -> GenerationConfig:
    generation_config = GenerationConfig()
    generation_config.num_groups = 3
    generation_config.group_size = 2
    generation_config.max_new_tokens = 30
    generation_config.num_return_sequences = generation_config.num_groups * generation_config.group_size
    return generation_config

def get_test_dataset() -> Tuple[List[str], List[GenerationConfig]]:
    prompts = [
        "What is OpenVINO?",
        "How are you?",
        "What is your name?",
        "Tell me something about Canada"
    ]
    generation_configs = [
        get_greedy(),
        get_beam_search(),
        get_greedy(),
        get_beam_search()
    ]
    return (prompts, generation_configs)

def get_scheduler_config(scheduler_params: dict = None) -> SchedulerConfig:
    scheduler_config = SchedulerConfig()
    if scheduler_params is None:
        scheduler_config.dynamic_split_fuse = True
        scheduler_config.num_kv_blocks = 300
        # vLLM specific
        scheduler_config.max_num_batched_tokens = 256
        scheduler_config.max_num_seqs = 256
    else:
        for param, value in scheduler_params.items():
            setattr(scheduler_config, param, value)

    return scheduler_config

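# Illustrative usage (not from the original file): individual scheduler fields can be
# overridden by passing a dict of attribute names used above, e.g.
#   get_scheduler_config({"num_kv_blocks": 500, "dynamic_split_fuse": False, "max_num_seqs": 128})
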
def convert_to_hf(
    default_generation_config: HFGenerationConfig,
    generation_config: GenerationConfig
) -> HFGenerationConfig:
    kwargs = {}

    # generic parameters
    kwargs['max_length'] = generation_config.max_length
    kwargs['max_new_tokens'] = generation_config.max_new_tokens

    # copy default parameters
    kwargs['eos_token_id'] = default_generation_config.eos_token_id
    kwargs['pad_token_id'] = default_generation_config.pad_token_id

    if generation_config.num_groups * generation_config.group_size > 1:
        # beam search case
        kwargs['num_beam_groups'] = generation_config.num_groups
        kwargs['num_beams'] = generation_config.num_groups * generation_config.group_size
        kwargs['diversity_penalty'] = generation_config.diversity_penalty
        kwargs['repetition_penalty'] = generation_config.repetition_penalty
        kwargs['length_penalty'] = generation_config.length_penalty
        kwargs['no_repeat_ngram_size'] = generation_config.no_repeat_ngram_size
        kwargs['num_return_sequences'] = generation_config.num_return_sequences
        kwargs['output_scores'] = True
    elif generation_config.do_sample:
        # multinomial sampling case
        kwargs['temperature'] = generation_config.temperature
        kwargs['top_k'] = generation_config.top_k
        kwargs['top_p'] = generation_config.top_p
        kwargs['do_sample'] = generation_config.do_sample
    else:
        # greedy case: no extra parameters are required
        pass

    hf_generation_config = HFGenerationConfig(**kwargs)
    return hf_generation_config

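# Worked example (derived from the helpers above): for get_beam_search() this produces an
# HFGenerationConfig with num_beam_groups=3, num_beams=6, num_return_sequences=6,
# max_new_tokens=30 and output_scores=True, plus the penalty values copied from the config.
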
def run_hugging_face(
    model_id: str,
    prompts: List[str],
    generation_configs: List[GenerationConfig],
    use_optimum: bool,
    tmp_path: Path
) -> Tuple[List[GenerationResult], Path]:
    hf_tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = OVModelForCausalLM.from_pretrained(model_id, export=True) if use_optimum else \
        AutoModelForCausalLM.from_pretrained(model_id)
    generation_results: List[GenerationResult] = []
    model_path: Path = tmp_path / model_id

    if use_optimum:
        model.save_pretrained(model_path)
        # convert tokenizers as well
        from openvino_tokenizers import convert_tokenizer
        from openvino import serialize
        tokenizer, detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
        serialize(tokenizer, model_path / "openvino_tokenizer.xml")
        serialize(detokenizer, model_path / "openvino_detokenizer.xml")

    for prompt, generation_config in zip(prompts, generation_configs):
        inputs = hf_tokenizer(prompt, return_tensors="pt")
        prompt_len = len(inputs['input_ids'][0])
        generate_outputs = model.generate(**inputs, generation_config=convert_to_hf(model.generation_config, generation_config), return_dict_in_generate=True)
        all_text_batch = hf_tokenizer.batch_decode([generated_ids[prompt_len:] for generated_ids in generate_outputs.sequences], skip_special_tokens=True)

        generation_result = GenerationResult()
        generation_result.m_generation_ids = all_text_batch
        # sequences_scores are available only for the beam search case
        if generation_config.is_beam_search:
            generation_result.m_scores = [score for score in generate_outputs.sequences_scores]
        generation_results.append(generation_result)

    return (generation_results, model_path)

def run_continuous_batching(
    model_path: Path,
    scheduler_config: SchedulerConfig,
    prompts: List[str],
    generation_configs: List[GenerationConfig]
) -> List[GenerationResult]:
    pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config)
    return pipe.generate(prompts, generation_configs)

def get_models_list(file_name: str):
    models = []
    with open(file_name) as f:
        for model_name in f:
            model_name = model_name.strip()
            # skip comments in the model scope file
            if model_name.startswith('#'):
                continue
            models.append(model_name)
    return models

def compare_results(hf_result, ov_result, generation_config):
    if generation_config.is_beam_search:
        assert len(hf_result.m_scores) == len(ov_result.m_scores)
        for hf_score, ov_score in zip(hf_result.m_scores, ov_result.m_scores):
            # Note: for fp32 / fp16 models the scores differ by less than 0.001;
            # the assert below uses a looser 0.02 tolerance.
            assert abs(hf_score - ov_score) < 0.02

    assert len(hf_result.m_generation_ids) == len(ov_result.m_generation_ids)
    for hf_text, ov_text in zip(hf_result.m_generation_ids, ov_result.m_generation_ids):
        assert hf_text == ov_text


def run_test_pipeline(tmp_path: str, model_id: str, scheduler_params: dict = None):
    prompts, generation_configs = get_test_dataset()
    scheduler_config = get_scheduler_config(scheduler_params)

    (hf_results, model_path) = run_hugging_face(model_id=model_id, prompts=prompts,
                                                generation_configs=generation_configs, tmp_path=tmp_path,
                                                use_optimum=True)
    ov_results: List[GenerationResult] = run_continuous_batching(model_path, scheduler_config, prompts,
                                                                 generation_configs)

    assert len(prompts) == len(hf_results)
    assert len(prompts) == len(ov_results)

    for prompt, hf_result, ov_result, generation_config in zip(prompts, hf_results, ov_results, generation_configs):
        print(f"Prompt = {prompt}\nHF result = {hf_result}\nOV result = {ov_result}")
        compare_results(hf_result, ov_result, generation_config)
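
# Illustrative sketch (an assumption, not part of the original file): a pytest test built on the
# helpers above. The model id is a placeholder chosen for illustration only.
#
# @pytest.mark.parametrize("model_id", ["facebook/opt-125m"])
# def test_greedy_and_beam_search(tmp_path, model_id):
#     run_test_pipeline(tmp_path, model_id)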