|
| 1 | +import json |
| 2 | +import os |
| 3 | +from typing import Dict, List, Any |
| 4 | + |
| 5 | +from helm.benchmark.scenarios.scenario import ( |
| 6 | + Input, |
| 7 | + Scenario, |
| 8 | + Instance, |
| 9 | + Reference, |
| 10 | + TRAIN_SPLIT, |
| 11 | + VALID_SPLIT, |
| 12 | + CORRECT_TAG, |
| 13 | + Output, |
| 14 | +) |
| 15 | +from helm.common.general import ensure_file_downloaded |
| 16 | + |
| 17 | + |
class ConvFinQACalcScenario(Scenario):
    """A mathematical calculation benchmark based on ConvFinQA.

    Each raw record is turned into a prompt made of a pseudo-markdown table,
    the optional supporting text facts, the previous dialogue turns with their
    answers, and the current question. The single correct reference is the
    string form of the record's ``exe_ans`` value.

    Data source:
        https://github.com/czyssrs/ConvFinQA

    Reference:
        Zhiyu Chen, Shiyang Li, Charese Smiley, Zhiqiang Ma, Sameena Shah, and William Yang Wang. 2022.
        ConvFinQA: Exploring the Chain of Numerical Reasoning in Conversational Finance Question Answering.
        In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing,
        pages 6279–6292, Abu Dhabi, United Arab Emirates. Association for Computational Linguistics.
        https://aclanthology.org/2022.emnlp-main.421
    """  # noqa: E501

    name = "conv_fin_qa_calc"
    description = "A mathematical calculation benchmark based on ConvFinQA: Exploring the Chain of Numerical Reasoning in Conversational Finance Question Answering [(Chen et al., 2022)](https://arxiv.org/pdf/2210.03849.pdf)."  # noqa: E501
    tags = ["question_answering", "finance"]

    # The URL is pinned to a specific commit of the upstream repository.
    DATASET_DOWNLOAD_URL: str = (
        "https://github.com/czyssrs/ConvFinQA/raw/cf3eed2d5984960bf06bb8145bcea5e80b0222a6/data.zip"
    )

    # Maps each split to the JSON file (inside the unpacked archive) that is loaded for it.
    _SPLIT_TO_JSON_FILE_NAME: Dict[str, str] = {TRAIN_SPLIT: "train_turn.json", VALID_SPLIT: "dev_turn.json"}

    def make_pseudo_markdown_table(self, table: List[List[Any]], sep: str = "\n") -> str:
        """Render ``table`` as pipe-delimited rows joined by ``sep``.

        Every cell is converted with ``str`` and each row is formatted as
        ``| cell | cell | ... |``. No header separator row is emitted, hence
        "pseudo" markdown. An empty table yields an empty string.
        """
        return sep.join(f"| {' | '.join(str(cell) for cell in row)} |" for row in table)

    def convert_to_instance(self, dic: Dict[str, Any], split: str, sep: str = "\n") -> Instance:
        """Build an ``Instance`` for ``split`` from one raw record.

        Args:
            dic: Raw record. Reads ``dic["table"]`` and, under
                ``dic["annotation"]``, the keys ``cur_dial``, ``exe_ans_list``,
                ``exe_ans`` and optionally ``gold_ind`` / ``gold_inds``.
            split: Split the resulting instance is assigned to.
            sep: Separator placed between table rows and prompt sections.

        Returns:
            An instance whose input is the assembled prompt and whose only
            reference is ``str(exe_ans)`` tagged as correct.
        """
        annotation = dic["annotation"]

        # Forward `sep` so the table rows use the same separator as the rest of the prompt.
        linearized_table = self.make_pseudo_markdown_table(dic["table"], sep=sep)
        input_text = f"Table: {sep}{linearized_table}{sep}{sep}"

        # The supporting facts are looked up under either spelling of the key.
        if "gold_ind" in annotation:
            facts = annotation["gold_ind"]
        elif "gold_inds" in annotation:
            facts = annotation["gold_inds"]
        else:
            facts = {}

        # Only facts whose key mentions "text" are included; they are concatenated as-is.
        table_text = "".join(fact for fact_type, fact in facts.items() if "text" in fact_type)
        if table_text:
            input_text += f"Text: {sep}{table_text}{sep}{sep}"

        # Earlier turns are shown with their answers; the final turn is the question to answer.
        questions = annotation["cur_dial"]
        last_index = len(questions) - 1
        for ind, q in enumerate(questions):
            if ind < last_index:
                input_text += f"Question: {q}{sep}Answer: {annotation['exe_ans_list'][ind]}{sep}"
            else:
                input_text += f"Question: {q}"

        answer = str(annotation["exe_ans"])
        return Instance(
            input=Input(text=input_text),
            references=[Reference(Output(text=answer), tags=[CORRECT_TAG])],
            split=split,
        )

    def get_instances(self, output_path: str) -> List[Instance]:
        """Download and unpack the dataset under ``output_path`` and load every split.

        Args:
            output_path: Directory in which the ``data`` folder is created.

        Returns:
            The instances of all splits listed in ``_SPLIT_TO_JSON_FILE_NAME``,
            in that mapping's order.
        """
        data_path = os.path.join(output_path, "data")
        ensure_file_downloaded(
            source_url=self.DATASET_DOWNLOAD_URL,
            target_path=data_path,
            unpack=True,
            unpack_type="unzip",
        )

        instances: List[Instance] = []
        for split, json_file_name in self._SPLIT_TO_JSON_FILE_NAME.items():
            json_file_path = os.path.join(data_path, json_file_name)
            # Read as UTF-8 explicitly rather than relying on the platform default encoding.
            with open(json_file_path, encoding="utf-8") as f:
                raw_instances = json.load(f)
            instances.extend(self.convert_to_instance(raw_instance, split) for raw_instance in raw_instances)
        return instances
0 commit comments