
Commit 7f44dd3

yifanmai, ryokawajp, and mtake authored

Add ConvFinQACalc (#3453)

Co-authored-by: Ryo Kawahara <ryokawa@jp.ibm.com>
Co-authored-by: Mikio Takeuchi <mtake@jp.ibm.com>

1 parent d556e18 commit 7f44dd3

4 files changed, +216 -0 lines changed
src/helm/benchmark/metrics/conv_fin_qa_calc_metrics.py

@@ -0,0 +1,72 @@
import re
from typing import Any, List

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import Metric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat
from helm.benchmark.scenarios.scenario import CORRECT_TAG
from helm.common.hierarchical_logger import hlog


def _strip_string(s: str) -> Any:
    # Extract the first numeric constant in the string; regex from https://stackoverflow.com/a/4703508
    numeric_const_pattern = r"[-+]?(?:(?:\d*\.\d+)|(?:\d+\.?))(?:[Ee][+-]?\d+)?"
    match = re.search(numeric_const_pattern, s)
    if match:
        try:
            return float(s[match.start() : match.end()])
        except Exception:
            return None
    return None


def float_equiv(str1: str, str2: str, eps: float = 1e-6) -> float:
    """Check if two values have the same float value, up to a small tolerance.

    This is the implementation used in the IBM Enterprise Benchmark paper.

    Note: This is a "mostly-correct" equality function and does not handle some cases correctly:

    - If either value cannot be parsed as a float, it returns 0.0,
      regardless of whether the raw strings match.
    - If the two values have different units (e.g. currency symbols,
      a trailing "M" or "B", or a trailing %), the values are not converted to the same
      units before comparison.
    """
    try:
        ss1 = _strip_string(str1)
        ss2 = _strip_string(str2)

        if ss1 is None or ss2 is None:
            hlog("WARNING: float_equiv returning 0.0 because at least one value is not a float")
            return 0.0
        return float(abs(ss1 - ss2) < eps)
    except Exception:
        return float(str1 == str2)


class ConvFinQACalcMetric(Metric):
    """Score metrics for ConvFinQACalc."""

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        assert request_state.result
        assert len(request_state.result.completions) == 1
        model_answer = request_state.result.completions[0].text

        assert len(request_state.instance.references) == 1
        assert len(request_state.instance.references[0].tags) == 1
        assert request_state.instance.references[0].tags[0] == CORRECT_TAG
        gold_answer = request_state.instance.references[0].output.text

        return [
            Stat(MetricName("float_equiv")).add(float_equiv(model_answer, gold_answer)),
        ]
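To make the scoring behavior concrete, here is a small illustrative sketch of float_equiv on ConvFinQA-style answers. The import path is inferred from the MetricSpec used in the run spec below, and the commented values follow from the regex-based stripping rather than from a verified run:

from helm.benchmark.metrics.conv_fin_qa_calc_metrics import float_equiv

float_equiv("21.5", "21.5")            # 1.0: exact numeric match
float_equiv("$ 21.5 million", "21.5")  # 1.0: the regex extracts the first number and ignores surrounding text
float_equiv("14.1%", "0.141")          # 0.0: "%" is not normalized, so 14.1 is compared against 0.141
float_equiv("n/a", "none")             # 0.0: neither string contains a number, so a warning is logged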

src/helm/benchmark/run_specs/enterprise_run_specs.py

+25
@@ -75,6 +75,31 @@ def get_financial_phrasebank_spec(agreement: int = 50) -> RunSpec:
    )


@run_spec_function("conv_fin_qa_calc")
def get_conv_fin_qa_calc_spec() -> RunSpec:
    scenario_spec = ScenarioSpec(
        class_name="helm.benchmark.scenarios.conv_fin_qa_calc_scenario.ConvFinQACalcScenario", args={}
    )

    adapter_spec = get_generation_adapter_spec(
        instructions="Based on the table, answer the final question. Respond with the answer only, with no additional explanation.",  # noqa: E501
        input_noun=None,
        output_noun="Answer",
    )

    metric_specs = [
        MetricSpec(class_name="helm.benchmark.metrics.conv_fin_qa_calc_metrics.ConvFinQACalcMetric")
    ] + get_basic_metric_specs([])

    return RunSpec(
        name="conv_fin_qa_calc",
        scenario_spec=scenario_spec,
        adapter_spec=adapter_spec,
        metric_specs=metric_specs,
        groups=["conv_fin_qa_calc"],
    )


# Legal
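The new run spec function can be exercised directly as a quick sanity check; in practice it would be selected by its registered name (conv_fin_qa_calc) through HELM's usual run machinery. A minimal sketch, assuming a standard HELM development install:

from helm.benchmark.run_specs.enterprise_run_specs import get_conv_fin_qa_calc_spec

spec = get_conv_fin_qa_calc_spec()
print(spec.name)                                  # conv_fin_qa_calc
print(spec.scenario_spec.class_name)              # ...conv_fin_qa_calc_scenario.ConvFinQACalcScenario
print([m.class_name for m in spec.metric_specs])  # ConvFinQACalcMetric plus the basic metric specs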

src/helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py

@@ -0,0 +1,97 @@
import json
import os
from typing import Dict, List, Any

from helm.benchmark.scenarios.scenario import (
    Input,
    Scenario,
    Instance,
    Reference,
    TRAIN_SPLIT,
    VALID_SPLIT,
    CORRECT_TAG,
    Output,
)
from helm.common.general import ensure_file_downloaded


class ConvFinQACalcScenario(Scenario):
    """A mathematical calculation benchmark based on ConvFinQA.

    Data source:
    https://github.com/czyssrs/ConvFinQA

    Reference:
    Zhiyu Chen, Shiyang Li, Charese Smiley, Zhiqiang Ma, Sameena Shah, and William Yang Wang. 2022.
    ConvFinQA: Exploring the Chain of Numerical Reasoning in Conversational Finance Question Answering.
    In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing,
    pages 6279–6292, Abu Dhabi, United Arab Emirates. Association for Computational Linguistics.
    https://aclanthology.org/2022.emnlp-main.421
    """  # noqa: E501

    name = "conv_fin_qa_calc"
    description = "A mathematical calculation benchmark based on ConvFinQA: Exploring the Chain of Numerical Reasoning in Conversational Finance Question Answering [(Chen et al., 2022)](https://arxiv.org/pdf/2210.03849.pdf)."  # noqa: E501
    tags = ["question_answering", "finance"]

    DATASET_DOWNLOAD_URL: str = (
        "https://github.com/czyssrs/ConvFinQA/raw/cf3eed2d5984960bf06bb8145bcea5e80b0222a6/data.zip"
    )

    _SPLIT_TO_JSON_FILE_NAME: Dict[str, str] = {TRAIN_SPLIT: "train_turn.json", VALID_SPLIT: "dev_turn.json"}

    def make_pseudo_markdown_table(self, table: List[List[Any]], sep: str = "\n") -> str:
        # Linearize each table row into a pipe-separated line.
        markdown_lines: List[str] = []

        for row in table:
            row_inner_markdown = " | ".join([str(cell) for cell in row])
            row_markdown = f"| {row_inner_markdown} |"
            markdown_lines.append(row_markdown)

        return sep.join(markdown_lines)

    def convert_to_instance(self, dic: Dict[str, Any], split: str, sep: str = "\n") -> Instance:
        linearized_table = self.make_pseudo_markdown_table(dic["table"])
        input_text = f"Table: {sep}{linearized_table}{sep}{sep}"

        if "gold_ind" in dic["annotation"]:
            facts = dic["annotation"]["gold_ind"]
        elif "gold_inds" in dic["annotation"]:
            facts = dic["annotation"]["gold_inds"]
        else:
            facts = {}
        table_text = ""
        for fact_type, fact in facts.items():
            if "text" in fact_type:
                table_text += fact
        if table_text:
            input_text += f"Text: {sep}{table_text}{sep}{sep}"

        for ind, q in enumerate(dic["annotation"]["cur_dial"]):
            if ind < len(dic["annotation"]["cur_dial"]) - 1:
                input_text += f"Question: {q}{sep}Answer: {dic['annotation']['exe_ans_list'][ind]}{sep}"
            else:
                input_text += f"Question: {q}"

        answer = str(dic["annotation"]["exe_ans"])
        return Instance(
            input=Input(text=input_text),
            references=[Reference(Output(text=answer), tags=[CORRECT_TAG])],
            split=split,
        )

    def get_instances(self, output_path: str) -> List[Instance]:
        data_path = os.path.join(output_path, "data")
        ensure_file_downloaded(
            source_url=self.DATASET_DOWNLOAD_URL,
            target_path=data_path,
            unpack=True,
            unpack_type="unzip",
        )
        instances: List[Instance] = []
        for split, json_file_name in self._SPLIT_TO_JSON_FILE_NAME.items():
            json_file_path = os.path.join(data_path, json_file_name)
            with open(json_file_path) as f:
                raw_instances = json.load(f)
            for raw_instance in raw_instances:
                instances.append(self.convert_to_instance(raw_instance, split))
        return instances
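As a rough illustration of the prompt construction (the table here is invented for the example, not drawn from the dataset, and the module path is inferred from the run spec's class_name), make_pseudo_markdown_table linearizes each row into a pipe-separated line; convert_to_instance then prepends a "Table:" header and appends the supporting text, the previous dialogue turns with their answers, and the final question:

from helm.benchmark.scenarios.conv_fin_qa_calc_scenario import ConvFinQACalcScenario

scenario = ConvFinQACalcScenario()
table = [["", "2018", "2017"], ["net revenue", "100.5", "98.2"]]
print(scenario.make_pseudo_markdown_table(table))
# |  | 2018 | 2017 |
# | net revenue | 100.5 | 98.2 |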

src/helm/benchmark/static/schema_enterprise.yaml

+22
@@ -72,6 +72,10 @@ metrics:
    display_name: Weighted F1
    description: Weighted F1 score
    lower_is_better: false
  - name: float_equiv
    display_name: Float Equivalence
    description: Float Equivalence
    lower_is_better: false

############################################################
perturbations: []
@@ -114,6 +118,7 @@ run_groups:
    subgroups:
      - gold_commodity_news
      - financial_phrasebank
      - conv_fin_qa_calc

  - name: legal_scenarios
    display_name: Legal Scenarios
@@ -156,6 +161,23 @@
      when: before 2013
      language: English

  - name: conv_fin_qa_calc
    display_name: ConvFinQACalc
    description: "A mathematical calculation benchmark based on ConvFinQA: Exploring the Chain of Numerical Reasoning in Conversational Finance Question Answering [(Chen et al., 2022)](https://arxiv.org/pdf/2210.03849.pdf)."
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: float_equiv
      main_split: valid
    taxonomy:
      task: question answering with numeric reasoning
      what: financial reports
      who: financial experts
      when: 1999 to 2019
      language: English

  - name: gold_commodity_news
    display_name: Gold Commodity News
    description: A classification benchmark based on a dataset of human-annotated gold commodity news headlines ([Sinha & Khandait, 2019](https://arxiv.org/abs/2009.04202)).
