@@ -40,6 +40,11 @@ def parse_args():
         default=None,
         help="Tokenizer for divergency metric. If not provided, it will be load from base_model or target_model.",
     )
+    parser.add_argument(
+        "--chat-template",
+        action="store_true",
+        help="Whether to apply the default chat template.",
+    )
     parser.add_argument(
         "--gt-data",
         default=None,
@@ -137,6 +142,11 @@ def parse_args():
         action="store_true",
         help="Use LLMPipeline from transformers library to instantiate the model.",
     )
+    parser.add_argument(
+        "--llamacpp",
+        action="store_true",
+        help="Use llama-cpp-python to instantiate the model.",
+    )
     parser.add_argument(
         "--image-size",
         type=int,
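
For context on the new flag: the hunk above only adds the store_true boolean, and the diff itself never shows how the llama.cpp model gets built. A minimal sketch, assuming llama-cpp-python's standard constructor (the GGUF path and n_ctx value are placeholder assumptions, not taken from this patch):

# Sketch only: how the --llamacpp backend might construct its model.
from llama_cpp import Llama

model = Llama(
    model_path="model.gguf",  # hypothetical local GGUF file
    n_ctx=2048,               # context window; tune per model
    verbose=False,
)
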
@@ -190,9 +200,13 @@ def load_prompts(args):
 def load_tokenizer(args):
     tokenizer = None
     if args.tokenizer is not None:
-        tokenizer = AutoTokenizer.from_pretrained(
-            args.tokenizer, trust_remote_code=True
-        )
+        if args.llamacpp:
+            from llama_cpp.llama_tokenizer import LlamaHFTokenizer
+            tokenizer = LlamaHFTokenizer.from_pretrained(args.tokenizer)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                args.tokenizer, trust_remote_code=True
+            )
     elif args.base_model is not None:
         tokenizer = AutoTokenizer.from_pretrained(
             args.base_model, trust_remote_code=True
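
The new branch above swaps in llama-cpp-python's LlamaHFTokenizer, which wraps a Hugging Face tokenizer so it can be used alongside a llama.cpp model. A rough usage sketch of what load_tokenizer() resolves to when --llamacpp is set (the model id is an illustrative assumption):

# Sketch: tokenizer loading on the --llamacpp path.
from llama_cpp.llama_tokenizer import LlamaHFTokenizer

tokenizer = LlamaHFTokenizer.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # placeholder HF model id
)
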
@@ -246,8 +260,29 @@ def diff_strings(a: str, b: str, *, use_loguru_colors: bool = False) -> str:
     return "".join(output)


-def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question):
-    return model.generate(question, do_sample=False, max_new_tokens=max_new_tokens)
+def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
+    if use_chat_template:
+        model.start_chat()
+        result = model.generate(question, do_sample=False, max_new_tokens=max_new_tokens)
+        model.finish_chat()
+        return result
+    else:
+        return model.generate(question, do_sample=False, max_new_tokens=max_new_tokens)
+
+
+def llamacpp_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
+    if use_chat_template:
+        output = model.create_chat_completion(messages=[{"role": "user", "content": question}], max_tokens=max_new_tokens, temperature=0.0)
+        text = output["choices"][0]["message"]["content"]
+        if skip_question:
+            text = text[len(question):]
+        return text
+    else:
+        output = model(question, max_tokens=max_new_tokens, echo=True, temperature=0.0)
+        text = output["choices"][0]["text"]
+        if skip_question:
+            text = text[len(question):]
+        return text


 def genai_gen_image(model, prompt, num_inference_steps, generator=None):
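
To make the two code paths of llamacpp_gen_text concrete, a minimal sketch of calling it with a llama-cpp-python model (the model path and prompt are placeholders):

# Chat path: create_chat_completion() applies the model's chat template
# and returns only the assistant message, without echoing the prompt.
from llama_cpp import Llama

model = Llama(model_path="model.gguf", verbose=False)
answer = llamacpp_gen_text(
    model, tokenizer=None, question="What is OpenVINO?",
    max_new_tokens=64, skip_question=False, use_chat_template=True,
)

# Raw path: echo=True makes the completion return prompt + continuation,
# so skip_question=True strips the prompt back off the front.
answer = llamacpp_gen_text(
    model, tokenizer=None, question="What is OpenVINO?",
    max_new_tokens=64, skip_question=True, use_chat_template=False,
)

Note that in the chat path the completion does not echo the prompt, so passing skip_question=True there would cut the first len(question) characters off the answer itself.
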
@@ -322,7 +357,15 @@ def create_evaluator(base_model, args):
     prompts = load_prompts(args)

     if task == "text":
-        tokenizer = load_tokenizer(args)
+        tokenizer = load_tokenizer(args) if not args.llamacpp else None
+
+        if args.genai:
+            gen_answer_fn = genai_gen_text
+        elif args.llamacpp:
+            gen_answer_fn = llamacpp_gen_text
+        else:
+            gen_answer_fn = None
+
         return EvaluatorCLS(
             base_model=base_model,
             gt_data=args.gt_data,
@@ -331,7 +374,8 @@ def create_evaluator(base_model, args):
             similarity_model_id=args.data_encoder,
             num_samples=args.num_samples,
             language=args.language,
-            gen_answer_fn=genai_gen_text if args.genai else None,
+            gen_answer_fn=gen_answer_fn,
+            use_chat_template=args.chat_template,
         )
     elif task == "text-to-image":
         return EvaluatorCLS(
@@ -467,10 +511,11 @@ def main():
         args.ov_config,
         args.hf,
         args.genai,
+        args.llamacpp
     )
     all_metrics_per_question, all_metrics = evaluator.score(
         target_model,
-        evaluator.get_generation_fn() if args.genai else None,
+        evaluator.get_generation_fn() if args.genai or args.llamacpp else None,
         output_dir=args.output
     )
     logger.info("Metrics for model: %s", args.target_model)