5
5
#include < cxxopts.hpp>
6
6
7
7
#include " continuous_batching_pipeline.hpp"
8
+ #include " tokenizer.hpp"
8
9
9
10
void print_generation_result (const GenerationResult& generation_result) {
10
11
for (size_t output_id = 0 ; output_id < generation_result.m_generation_ids .size (); ++output_id) {
@@ -46,15 +47,15 @@ int main(int argc, char* argv[]) try {
46
47
std::vector<std::string> prompt_examples = {
47
48
" What is OpenVINO?" ,
48
49
" How are you?" ,
49
- " What is OpenVINO ?" ,
50
- " What is the current time " ,
50
+ " What is your name ?" ,
51
+ " Tell me something about Canada " ,
51
52
" What is OpenVINO?" ,
52
53
};
53
54
54
55
std::vector<GenerationConfig> sampling_params_examples {
55
56
GenerationConfig::beam_search (),
56
- // GenerationConfig::greedy(),
57
- // GenerationConfig::multinomial(),
57
+ GenerationConfig::greedy (),
58
+ GenerationConfig::multinomial (),
58
59
};
59
60
60
61
std::vector<std::string> prompts (num_prompts);
@@ -66,7 +67,7 @@ int main(int argc, char* argv[]) try {
66
67
}
67
68
68
69
// Perform the inference
69
-
70
+
70
71
SchedulerConfig scheduler_config {
71
72
// batch size
72
73
.max_num_batched_tokens = 32 ,
@@ -84,21 +85,20 @@ int main(int argc, char* argv[]) try {
84
85
85
86
for (size_t request_id = 0 ; request_id < generation_results.size (); ++request_id) {
86
87
const GenerationResult & generation_result = generation_results[request_id];
87
-
88
88
std::cout << " Question: " << prompts[request_id] << std::endl;
89
89
switch (generation_result.m_status )
90
90
{
91
- case GenerationResultStatus ::FINISHED:
91
+ case GenerationStatus ::FINISHED:
92
92
print_generation_result (generation_result);
93
93
break ;
94
- case GenerationResultStatus ::IGNORED:
94
+ case GenerationStatus ::IGNORED:
95
95
std::cout << " Request was ignored due to lack of memory." <<std::endl;
96
96
if (generation_result.m_generation_ids .size () > 0 ) {
97
97
std::cout << " Partial result:" << std::endl;
98
98
print_generation_result (generation_result);
99
99
}
100
100
break ;
101
- case GenerationResultStatus::ABORTED :
101
+ case GenerationStatus::DROPPED_BY_PIPELINE :
102
102
std::cout << " Request was aborted." <<std::endl;
103
103
if (generation_result.m_generation_ids .size () > 0 ) {
104
104
std::cout << " Partial result:" << std::endl;
@@ -110,7 +110,6 @@ int main(int argc, char* argv[]) try {
110
110
}
111
111
std::cout << std::endl;
112
112
}
113
-
114
113
} catch (const std::exception & error) {
115
114
std::cerr << error.what () << ' \n ' ;
116
115
return EXIT_FAILURE;
0 commit comments