1
1
import datasets
2
2
import os
3
3
import re
4
- from typing import Any
4
+ from typing import Any , Optional
5
5
6
6
from helm .common .general import ensure_directory_exists
7
7
from helm .benchmark .adaptation .request_state import RequestState
@@ -15,7 +15,16 @@ class AIRBench2024Annotator(Annotator):
15
15
16
16
name = "air_bench_2024"
17
17
18
- def __init__ (self , auto_client : AutoClient , file_storage_path : str ):
18
+ _DEFAULT_MODEL = "openai/gpt-4o-2024-05-13"
19
+ _DEFAULT_MODEL_DEPLOYMENT = "openai/gpt-4o-2024-05-13"
20
+
21
+ def __init__ (
22
+ self ,
23
+ auto_client : AutoClient ,
24
+ file_storage_path : str ,
25
+ model : Optional [str ] = None ,
26
+ model_deployment : Optional [str ] = None ,
27
+ ):
19
28
self ._auto_client = auto_client
20
29
cache_dir = os .path .join (file_storage_path , "data" )
21
30
ensure_directory_exists (cache_dir )
@@ -29,6 +38,8 @@ def __init__(self, auto_client: AutoClient, file_storage_path: str):
29
38
self ._category_id_to_judge_prompt = {row ["cate-idx" ]: row ["judge_prompt" ] for row in dataset }
30
39
# Regex pattern is lenient to allow for typos, e.g. extra whitespace
31
40
self ._pattern = re .compile ("##\s*short_reasoning\s*:(.*)##\s*the_score\s*:(.*)" , re .DOTALL )
41
+ self ._model = model or self ._DEFAULT_MODEL
42
+ self ._model_deployment = model_deployment or self ._DEFAULT_MODEL_DEPLOYMENT
32
43
33
44
def annotate (self , request_state : RequestState ) -> Any :
34
45
assert request_state .result
@@ -45,8 +56,8 @@ def annotate(self, request_state: RequestState) -> Any:
45
56
"{{ANSWER}}" , model_output_text
46
57
)
47
58
annotator_request = Request (
48
- model = "openai/gpt-4o-2024-05-13" ,
49
- model_deployment = "openai/gpt-4o-2024-05-13" ,
59
+ model = self . _model ,
60
+ model_deployment = self . _model_deployment ,
50
61
prompt = annotator_prompt ,
51
62
temperature = 0.0 ,
52
63
max_tokens = 64 ,
0 commit comments