Skip to content

Commit d2a915a

Browse files
authored
Merge branch 'stanford-crfm:main' into main
2 parents 06c0f5b + 4183b44 commit d2a915a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+3534
-497
lines changed

helm-frontend/src/components/AnnotationsDisplay.tsx

+51-22
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,59 @@ import CompletionAnnotation from "@/types/CompletionAnnotation";
22
import Preview from "./Preview";
33
import MediaObjectDisplay from "./MediaObjectDisplay";
44

5+
// TODO: This is a dirty hack to support annotations from
6+
// Image2Structure and AIRBench, but eventually we should make sure
7+
// all annotations are supported generally.
58
type Props = {
69
predictionAnnotations:
7-
| Record<string, Array<CompletionAnnotation>>
10+
| Record<
11+
string,
12+
Array<CompletionAnnotation> | Record<string, string | number>
13+
>
814
| undefined;
915
};
1016

17+
function listAnnotationDisplay(listAnnotation: Array<CompletionAnnotation>) {
18+
return (
19+
<div>
20+
{listAnnotation.map((annotation, idx) => (
21+
<div key={idx}>
22+
{annotation.error && (
23+
<div>
24+
<h3 className="ml-1">Error</h3>
25+
<Preview value={annotation["error"]} />{" "}
26+
</div>
27+
)}
28+
{annotation.text && (
29+
<div>
30+
<h3 className="ml-1">Text</h3>
31+
<Preview value={annotation["text"]} />{" "}
32+
</div>
33+
)}
34+
{annotation.media_object && (
35+
<MediaObjectDisplay mediaObject={annotation["media_object"]} />
36+
)}
37+
</div>
38+
))}
39+
</div>
40+
);
41+
}
42+
43+
function dictAnnotationDisplay(
44+
dictAnnotation: Record<string, string | number>,
45+
) {
46+
return (
47+
<div>
48+
{Object.entries(dictAnnotation).map(([key, value]) => (
49+
<div>
50+
<h3 className="ml-1">{key}</h3>
51+
<Preview value={value.toString()} />
52+
</div>
53+
))}
54+
</div>
55+
);
56+
}
57+
1158
export default function AnnotationDisplay({ predictionAnnotations }: Props) {
1259
return (
1360
<div>
@@ -17,27 +64,9 @@ export default function AnnotationDisplay({ predictionAnnotations }: Props) {
1764
<h3>
1865
<strong>{key}</strong>
1966
</h3>
20-
{value.map((annotation, idx) => (
21-
<div key={idx}>
22-
{annotation.error && (
23-
<div>
24-
<h3 className="ml-1">Error</h3>
25-
<Preview value={annotation["error"]} />{" "}
26-
</div>
27-
)}
28-
{annotation.text && (
29-
<div>
30-
<h3 className="ml-1">Text</h3>
31-
<Preview value={annotation["text"]} />{" "}
32-
</div>
33-
)}
34-
{annotation.media_object && (
35-
<MediaObjectDisplay
36-
mediaObject={annotation["media_object"]}
37-
/>
38-
)}
39-
</div>
40-
))}
67+
{Array.isArray(value)
68+
? listAnnotationDisplay(value)
69+
: dictAnnotationDisplay(value)}
4170
</div>
4271
))
4372
: null}

helm-frontend/src/utils/getReleaseUrl.ts

+3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ export default function getReleaseUrl(
55
if (!currProjectId) {
66
return "#";
77
}
8+
if (currProjectId === "home") {
9+
return `https://crfm.stanford.edu/helm/`;
10+
}
811
if (!version) {
912
return `https://crfm.stanford.edu/helm/${currProjectId}/latest/`;
1013
}

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ PyWavelets==1.4.1
238238
PyYAML==6.0.1
239239
referencing==0.35.1
240240
regex==2024.5.10
241+
reka-api==2.0.0
241242
requests==2.31.0
242243
requests-oauthlib==2.0.0
243244
retrying==1.3.4
+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Reads all runs from the suite and writes them to the CSV folder in CSV format.
2+
3+
EXPERIMENTAL: Not for public use.
4+
TEMPORARY: Delete after 2024-09-30"""
5+
6+
import argparse
7+
import csv
8+
import os
9+
import re
10+
11+
from tqdm import tqdm
12+
13+
from helm.benchmark.adaptation.scenario_state import ScenarioState
14+
from helm.common.codec import from_json
15+
from helm.common.general import ensure_directory_exists
16+
17+
18+
class FieldNames:
19+
CATEGORY_ID = "cate-idx"
20+
L2_NAME = "l2-name"
21+
L3_NAME = "l3-name"
22+
L4_NAME = "l4-name"
23+
PROMPT = "prompt"
24+
RESPONSE = "response"
25+
JUDGE_PROMPT = "judge_prompt"
26+
SCORE_REASON = "score_reason"
27+
SCORE = "score"
28+
29+
30+
def process_one(scenario_state_path: str, csv_file_path: str):
31+
with open(scenario_state_path) as f:
32+
scenario_state = from_json(f.read(), ScenarioState)
33+
34+
fieldnames = [
35+
FieldNames.CATEGORY_ID,
36+
FieldNames.L2_NAME,
37+
FieldNames.L3_NAME,
38+
FieldNames.L4_NAME,
39+
FieldNames.PROMPT,
40+
FieldNames.RESPONSE,
41+
FieldNames.JUDGE_PROMPT,
42+
FieldNames.SCORE_REASON,
43+
FieldNames.SCORE,
44+
]
45+
with open(csv_file_path, "w", newline="") as output_file:
46+
writer = csv.DictWriter(output_file, fieldnames=fieldnames)
47+
writer.writeheader()
48+
for request_state in scenario_state.request_states:
49+
row = {}
50+
references = request_state.instance.references
51+
assert len(references) == 4
52+
row[FieldNames.CATEGORY_ID] = references[0].output.text
53+
row[FieldNames.L2_NAME] = references[1].output.text
54+
row[FieldNames.L3_NAME] = references[2].output.text
55+
row[FieldNames.L4_NAME] = references[3].output.text
56+
row[FieldNames.PROMPT] = request_state.request.prompt
57+
assert request_state.result
58+
assert len(request_state.result.completions) == 1
59+
row[FieldNames.RESPONSE] = request_state.result.completions[0].text
60+
assert request_state.annotations
61+
row[FieldNames.JUDGE_PROMPT] = request_state.annotations["air_bench_2024"]["prompt_text"]
62+
row[FieldNames.SCORE_REASON] = request_state.annotations["air_bench_2024"]["reasoning"]
63+
row[FieldNames.SCORE] = request_state.annotations["air_bench_2024"]["score"]
64+
writer.writerow(row)
65+
print(f"Wrote {csv_file_path}")
66+
67+
68+
def process_all(suite_path: str, csv_path: str):
69+
ensure_directory_exists(csv_path)
70+
run_dir_names = sorted([p for p in os.listdir(suite_path) if p.startswith("air_bench_2024:")])
71+
for run_dir_name in tqdm(run_dir_names, disable=None):
72+
scenario_state_path = os.path.join(suite_path, run_dir_name, "scenario_state.json")
73+
if not os.path.isfile(scenario_state_path):
74+
continue
75+
model_name_match = re.search("model=([A-Za-z0-9_-]+)", run_dir_name)
76+
assert model_name_match
77+
model_name = model_name_match[1]
78+
csv_file_path = os.path.join(csv_path, f"{model_name}_result.csv")
79+
process_one(scenario_state_path, csv_file_path)
80+
81+
82+
def main():
83+
parser = argparse.ArgumentParser()
84+
parser.add_argument(
85+
"-o",
86+
"--output-path",
87+
type=str,
88+
help="Where the benchmarking output lives",
89+
default="benchmark_output",
90+
)
91+
parser.add_argument(
92+
"--csv-path",
93+
type=str,
94+
help="Name of the CSV folder.",
95+
default="csv_output",
96+
)
97+
parser.add_argument(
98+
"--suite",
99+
type=str,
100+
help="Name of the suite.",
101+
required=True,
102+
)
103+
args = parser.parse_args()
104+
suite_path = os.path.join(args.output_path, "runs", args.suite)
105+
process_all(suite_path, args.csv_path)
106+
107+
108+
if __name__ == "__main__":
109+
main()

setup.cfg

+7
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,15 @@ models =
161161
crfm-helm[google]
162162
crfm-helm[mistral]
163163
crfm-helm[openai]
164+
crfm-helm[reka]
164165
crfm-helm[together]
165166
crfm-helm[tsinghua]
166167
crfm-helm[yandex]
167168
crfm-helm[openvino]
168169

170+
reka =
171+
reka-api~=2.0.0
172+
169173
vlm =
170174
crfm-helm[openai]
171175

@@ -182,6 +186,9 @@ vlm =
182186
scipy~=1.10
183187
torchvision>=0.14.1,<3.0.0
184188

189+
# For Reka AI
190+
crfm-helm[reka]
191+
185192
# VLM scenarios
186193
crfm-helm[images]
187194
crfm-helm[image2structure]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import datasets
2+
import os
3+
import re
4+
from typing import Any
5+
6+
from helm.common.general import ensure_directory_exists
7+
from helm.benchmark.adaptation.request_state import RequestState
8+
from helm.benchmark.annotation.annotator import Annotator
9+
from helm.clients.auto_client import AutoClient
10+
from helm.common.request import Request
11+
12+
13+
class AIRBench2024Annotator(Annotator):
14+
"""The AIRBench 2024 autograder."""
15+
16+
name = "air_bench_2024"
17+
18+
def __init__(self, auto_client: AutoClient, file_storage_path: str):
19+
self._auto_client = auto_client
20+
cache_dir = os.path.join(file_storage_path, "data")
21+
ensure_directory_exists(cache_dir)
22+
dataset = datasets.load_dataset(
23+
"yifanmai/air-bench-2024-beta", "judge_prompts", split="test", cache_dir=cache_dir
24+
)
25+
self._category_id_to_judge_prompt = {row["cate-idx"]: row["judge_prompt"] for row in dataset}
26+
# Regex pattern is lenient to allow for typos e.g. extra whitespace
27+
self._pattern = re.compile("##\s*short_reasoning\s*:(.*)##\s*the_score\s*:(.*)", re.DOTALL)
28+
29+
def annotate(self, request_state: RequestState) -> Any:
30+
assert request_state.result
31+
assert len(request_state.result.completions) == 1
32+
model_input_text = request_state.request.prompt
33+
model_output_text = request_state.result.completions[0].text
34+
if not model_output_text.strip():
35+
return {"prompt_text": "", "reasoning": "BLOCKED_REQUEST_OR_EMPTY_RESPONSE", "score": 0.0}
36+
category_id = request_state.instance.references[0].output.text
37+
prompt_template = self._category_id_to_judge_prompt[category_id]
38+
# Strip to deal with incorrectly formatted input CSV.
39+
# TODO: Stop stripping after CSV is fixed.
40+
annotator_prompt = prompt_template.replace("{{QUESTION}}", model_input_text).replace(
41+
"{{ANSWER}}", model_output_text
42+
)
43+
annotator_request = Request(
44+
model="openai/gpt-4o-2024-05-13",
45+
model_deployment="openai/gpt-4o-2024-05-13",
46+
prompt=annotator_prompt,
47+
temperature=0.0,
48+
max_tokens=64,
49+
)
50+
annotator_response = self._auto_client.make_request(annotator_request)
51+
if not annotator_response.success:
52+
raise Exception(f"Annotation request failed: {annotator_response.error}")
53+
assert len(annotator_response.completions) == 1
54+
annotator_response_text = annotator_response.completions[0].text
55+
annotator_response_parts = self._pattern.search(annotator_response_text)
56+
if not annotator_response_parts:
57+
raise Exception(f"Malformed annotator response: {annotator_response_text}")
58+
reasoning = annotator_response_parts[1].strip()
59+
try:
60+
score = float(annotator_response_parts[2].strip())
61+
except ValueError as e:
62+
raise Exception(f"Malformed annotator response: {annotator_response_text}") from e
63+
64+
return {"prompt_text": annotator_prompt, "reasoning": reasoning, "score": score}

src/helm/benchmark/annotation/annotator_factory.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from typing import Any, Dict, Mapping, Optional
33

4+
from helm.clients.auto_client import AutoClient
45
from helm.common.credentials_utils import provide_api_key
56
from helm.common.cache_backend_config import CacheBackendConfig, CacheConfig
67
from helm.common.hierarchical_logger import hlog
@@ -46,6 +47,11 @@ def get_annotator(self, annotator_spec: AnnotatorSpec) -> Annotator:
4647
provider_bindings={
4748
"api_key": lambda: provide_api_key(self.credentials, annotator_name),
4849
"file_storage_path": lambda: self._get_file_storage_path(annotator_name),
50+
"auto_client": lambda: AutoClient(
51+
credentials=self.credentials,
52+
file_storage_path=self.file_storage_path,
53+
cache_backend_config=self.cache_backend_config,
54+
),
4955
},
5056
)
5157
annotator = create_object(annotator_spec)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import math
2+
import json
3+
from typing import List, Union
4+
5+
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
6+
from helm.benchmark.adaptation.request_state import RequestState
7+
from helm.benchmark.metrics.metric import Metric
8+
from helm.benchmark.metrics.metric_name import MetricName
9+
from helm.benchmark.metrics.metric_service import MetricService
10+
from helm.benchmark.metrics.statistic import Stat
11+
from helm.benchmark.metrics.fin_qa_metrics_helper import ( # type: ignore
12+
equal_program,
13+
eval_program,
14+
program_tokenization,
15+
)
16+
17+
18+
def _get_program_accuracy(reference_program: List[str], generated_program: List[str]) -> float:
19+
return 1.0 if equal_program(reference_program, generated_program) else 0.0
20+
21+
22+
def _get_execution_accuracy(reference_execution: str, generated_program: List[str], table: List[List[str]]) -> float:
23+
invalid_flag: int
24+
generated_result: Union[str, float]
25+
invalid_flag, generated_result = eval_program(generated_program, table)
26+
if invalid_flag:
27+
return 0.0
28+
if reference_execution == "yes" or reference_execution == "no":
29+
return 1.0 if reference_execution == generated_result else 0
30+
else:
31+
if not isinstance(generated_result, float):
32+
return 0.0
33+
return 1.0 if math.isclose(float(reference_execution), generated_result) else 0
34+
35+
36+
class FinQAMetric(Metric):
37+
def evaluate_generation(
38+
self,
39+
adapter_spec: AdapterSpec,
40+
request_state: RequestState,
41+
metric_service: MetricService,
42+
eval_cache_path: str,
43+
) -> List[Stat]:
44+
assert len(request_state.instance.references) == 3
45+
reference_text = request_state.instance.references[0].output.text
46+
reference_program = program_tokenization(reference_text)
47+
reference_execution = request_state.instance.references[1].output.text
48+
table: List[List[str]] = json.loads(request_state.instance.references[2].output.text)
49+
50+
assert request_state.result
51+
assert len(request_state.result.completions) == 1
52+
generated_text = request_state.result.completions[0].text.strip()
53+
generated_program = program_tokenization(generated_text)
54+
55+
return [
56+
Stat(MetricName("program_accuracy")).add(_get_program_accuracy(reference_program, generated_program)),
57+
Stat(MetricName("execution_accuracy")).add(
58+
_get_execution_accuracy(reference_execution, generated_program, table)
59+
),
60+
]

0 commit comments

Comments
 (0)