
Commit c7d4866

Add HotPotQA and SQuAD scenarios from RULER (#3411)
1 parent 185bc24 commit c7d4866

4 files changed: +539 -0 lines changed
@@ -0,0 +1,70 @@
from helm.benchmark.adaptation.adapter_spec import ADAPT_GENERATION, AdapterSpec
from helm.benchmark.metrics.common_metric_specs import get_open_ended_generation_metric_specs
from helm.benchmark.run_spec import RunSpec, run_spec_function
from helm.benchmark.scenarios.scenario import ScenarioSpec


@run_spec_function("ruler_hotpotqa")
def get_ruler_hotpotqa_spec() -> RunSpec:
    scenario_spec = ScenarioSpec(
        class_name="helm.benchmark.scenarios.ruler_qa_scenarios.RULERHotpotQAScenario", args={}
    )

    adapter_spec = AdapterSpec(
        method=ADAPT_GENERATION,
        global_prefix="",
        global_suffix="",
        instructions="",
        input_prefix="",
        input_suffix="",
        reference_prefix="A. ",
        reference_suffix="",
        output_prefix="",
        output_suffix="",
        instance_prefix="",
        max_train_instances=0,
        num_outputs=1,
        temperature=0.0,
        max_tokens=512,  # ?
        stop_sequences=[],
    )

    return RunSpec(
        name="ruler_hotpotqa",
        scenario_spec=scenario_spec,
        adapter_spec=adapter_spec,
        metric_specs=get_open_ended_generation_metric_specs(),
        groups=["ruler_hotpotqa"],
    )


@run_spec_function("ruler_squad")
def get_ruler_squad_spec() -> RunSpec:
    scenario_spec = ScenarioSpec(class_name="helm.benchmark.scenarios.ruler_qa_scenarios.RULERSQuADScenario")

    adapter_spec = AdapterSpec(
        method=ADAPT_GENERATION,
        global_prefix="",
        global_suffix="",
        instructions="",
        input_prefix="",
        input_suffix="",
        reference_prefix="A. ",
        reference_suffix="",
        output_prefix="",
        output_suffix="",
        instance_prefix="",
        max_train_instances=0,
        num_outputs=1,
        temperature=0.0,
        max_tokens=512,  # ?
        stop_sequences=[],
    )

    return RunSpec(
        name="ruler_squad",
        scenario_spec=scenario_spec,
        adapter_spec=adapter_spec,
        metric_specs=get_open_ended_generation_metric_specs(),
        groups=["ruler_squad"],
    )
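
The two run spec functions above register ruler_hotpotqa and ruler_squad with a zero-shot, temperature-0 generation adapter. A minimal sanity check (not part of this commit, and assuming it is appended to the same file so the names above are in scope) might look like this:

spec = get_ruler_hotpotqa_spec()
# The adapter should do plain generation with no in-context training examples.
assert spec.adapter_spec.method == ADAPT_GENERATION
assert spec.adapter_spec.max_train_instances == 0
assert spec.adapter_spec.temperature == 0.0
# The scenario class is resolved lazily from its fully qualified name.
print(spec.scenario_spec.class_name)
# -> helm.benchmark.scenarios.ruler_qa_scenarios.RULERHotpotQAScenario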
@@ -0,0 +1,171 @@
# flake8: noqa
# type: ignore
# fmt: off

import json
import random
import re
from typing import Any, List

import numpy as np
from tqdm import tqdm


# The following code is copied verbatim from:
# https://github.com/NVIDIA/RULER/blob/860f2bd5c0430569f5941176f9f97f95e770b3da/scripts/data/synthetic/qa.py
# under the following license:
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License


# Read SQuAD QA dataset
def read_squad(file):
    with open(file) as f:
        data = json.load(f)

    total_docs = [p['context'] for d in data['data'] for p in d['paragraphs']]
    total_docs = sorted(list(set(total_docs)))
    total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}

    total_qas = []
    for d in data['data']:
        more_docs = [total_docs_dict[p['context']] for p in d['paragraphs']]
        for p in d['paragraphs']:
            for qas in p['qas']:
                if not qas['is_impossible']:
                    total_qas.append({
                        'query': qas['question'],
                        'outputs': [a['text'] for a in qas['answers']],
                        'context': [total_docs_dict[p['context']]],
                        'more_context': [idx for idx in more_docs if idx != total_docs_dict[p['context']]]
                    })

    return total_qas, total_docs

# Read Hotpot QA dataset
def read_hotpotqa(file):
    with open(file) as f:
        data = json.load(f)

    total_docs = [f"{t}\n{''.join(p)}" for d in data for t, p in d['context']]
    total_docs = sorted(list(set(total_docs)))
    total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}

    total_qas = []
    for d in data:
        total_qas.append({
            'query': d['question'],
            'outputs': [d['answer']],
            'context': [total_docs_dict[f"{t}\n{''.join(p)}"] for t, p in d['context']],
        })

    return total_qas, total_docs


DOCUMENT_PROMPT = "Document {i}:\n{document}"

def generate_input_output(index, num_docs, template: str, random_seed: int, qas: Any, docs: Any):
    curr_q = qas[index]['query']
    curr_a = qas[index]['outputs']
    curr_docs = qas[index]['context']
    curr_more = qas[index].get('more_context', [])
    if num_docs < len(docs):
        if (num_docs - len(curr_docs)) > len(curr_more):
            addition_docs = [i for i, d in enumerate(docs) if i not in curr_docs + curr_more]
            all_docs = curr_docs + curr_more + random.sample(addition_docs, max(0, num_docs - len(curr_docs) - len(curr_more)))
        else:
            all_docs = curr_docs + random.sample(curr_more, num_docs - len(curr_docs))

        all_docs = [docs[idx] for idx in all_docs]
    else:
        all_docs = docs

    random.Random(random_seed).shuffle(all_docs)

    context = '\n\n'.join([DOCUMENT_PROMPT.format(i=i+1, document=d) for i, d in enumerate(all_docs)])
    input_text = template.format(
        context=context,
        query=curr_q
    )
    return input_text, curr_a


# The following code has been modified from the original source from:
# https://github.com/NVIDIA/RULER/blob/860f2bd5c0430569f5941176f9f97f95e770b3da/scripts/data/synthetic/qa.py
# under the same Apache 2.0 license included above.


def _text_to_tokens(text: str) -> List[str]:
    return re.split(r"\s+", text.strip())


def generate_samples(dataset: str, dataset_path: str, template: str, random_seed: int, pre_samples: int, num_samples: int, tokens_to_generate: int, max_seq_length: int, incremental: int = 10, remove_newline_tab: bool = False):
    random.seed(random_seed)
    np.random.seed(random_seed)

    if dataset == 'squad':
        qas, docs = read_squad(dataset_path)
    elif dataset == 'hotpotqa':
        qas, docs = read_hotpotqa(dataset_path)
    else:
        raise NotImplementedError(f'{dataset} is not implemented.')

    write_jsons = []
    tokens_to_generate = tokens_to_generate

    # Find the perfect num_docs
    num_docs = incremental

    total_tokens = 0  # Track the total tokens generated for this example
    while total_tokens + tokens_to_generate < max_seq_length:
        input_text, answer = generate_input_output(0, num_docs, template=template, random_seed=random_seed, qas=qas, docs=docs)
        # Calculate the number of tokens in the example
        total_tokens = len(_text_to_tokens(input_text + f' {answer}'))
        print(f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}')
        if total_tokens + tokens_to_generate > max_seq_length:
            num_docs -= incremental
            break

        num_docs += incremental
        if num_docs > len(docs):
            num_docs = len(docs)
            break
    print('Number of documents:', num_docs)

    # Generate samples
    for index in tqdm(range(num_samples)):
        used_docs = num_docs
        while True:
            try:
                input_text, answer = generate_input_output(index + pre_samples, used_docs, template=template, random_seed=random_seed, qas=qas, docs=docs)
                length = len(_text_to_tokens(input_text)) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except:
                if used_docs > incremental:
                    used_docs -= incremental

        if remove_newline_tab:
            input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split())

        formatted_output = {
            "index": index,
            "input": input_text,
            "outputs": answer,
            "length": length
        }
        write_jsons.append(formatted_output)

    return write_jsons
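
For illustration only (not part of this commit), the helper above can be exercised on its own against a locally downloaded SQuAD dev file; the path and the simplified template below are assumptions, not values used by HELM:

# Hypothetical standalone driver for generate_samples.
SQUAD_PATH = "dev-v2.0.json"  # assumed local copy of the SQuAD 2.0 dev set
TEMPLATE = "Documents:\n\n{context}\n\nQuestion: {query} Answer:"  # illustrative template only

samples = generate_samples(
    dataset="squad",
    dataset_path=SQUAD_PATH,
    template=TEMPLATE,
    random_seed=42,
    pre_samples=0,
    num_samples=5,
    tokens_to_generate=32,
    max_seq_length=4096,
)
for sample in samples:
    # Each sample carries the fully assembled prompt, the gold answers, and an
    # approximate whitespace-token length.
    print(sample["index"], sample["length"], sample["outputs"])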
@@ -0,0 +1,88 @@
import os
from typing import List, Optional

from helm.common.general import ensure_directory_exists, ensure_file_downloaded
from helm.benchmark.scenarios.ruler_qa_scenario_helper import generate_samples  # type: ignore
from helm.benchmark.scenarios.scenario import (
    VALID_SPLIT,
    Scenario,
    Instance,
    Reference,
    CORRECT_TAG,
    Input,
    Output,
)


_DATASET_TO_URL = {
    "hotpotqa": "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json",
    "squad": "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
}


class _RULERQAScenario(Scenario):
    name = "ruler_qa"
    description = "A QA scenario from RULER"
    tags = ["long_context", "rag"]

    _TEMPLATE = """Answer the question based on the given documents. Only give me the answer and do not output any other words.

The following are given documents.

{context}

Answer the question based on the given documents. Only give me the answer and do not output any other words.

Question: {query} Answer:"""  # noqa: E501

    def __init__(self, dataset: str, max_sequence_length: Optional[int] = None):
        super().__init__()
        self.dataset = dataset or "hotpotqa"
        self.max_sequence_length = max_sequence_length or 32768

    def get_instances(self, output_path: str) -> List[Instance]:
        data_dir = os.path.join(output_path, "data")
        ensure_directory_exists(data_dir)
        file_path = os.path.join(data_dir, f"{self.dataset}.json")
        url = _DATASET_TO_URL[self.dataset]
        ensure_file_downloaded(url, file_path)
        instances: List[Instance] = []
        samples = generate_samples(
            dataset=self.dataset,
            dataset_path=file_path,
            max_seq_length=self.max_sequence_length,
            tokens_to_generate=32,
            num_samples=500,
            random_seed=42,
            pre_samples=0,
            template=self._TEMPLATE,
        )
        for sample in samples:
            instance = Instance(
                id=sample["index"],
                input=Input(text=sample["input"]),
                references=[
                    Reference(Output(text=output_text), tags=[CORRECT_TAG]) for output_text in sample["outputs"]
                ],
                split=VALID_SPLIT,
            )
            instances.append(instance)
        return instances


class RULERHotpotQAScenario(_RULERQAScenario):
    name = "ruler_hotpotqa"
    description = "The HotpotQA long-context multi-hop RAG question answering scenario from RULER"
    tags = ["long_context", "rag"]

    def __init__(self, dataset: Optional[str] = None, max_sequence_length: Optional[int] = None):
        super().__init__("hotpotqa", max_sequence_length)


class RULERSQuADScenario(_RULERQAScenario):
    name = "ruler_squad"
    description = "The SQuAD question answering scenario from RULER"
    tags = ["long_context", "rag"]

    def __init__(self, dataset: Optional[str] = None, max_sequence_length: Optional[int] = None):
        super().__init__("squad", max_sequence_length)
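
A rough usage sketch (not part of this commit; assumes HELM is installed and uses an arbitrary local output directory) showing how one of the scenarios above yields instances:

# Hypothetical direct use of the scenario; get_instances downloads the SQuAD
# dev set into output_path/data and builds 500 validation instances.
scenario = RULERSQuADScenario(max_sequence_length=8192)
instances = scenario.get_instances(output_path="ruler_squad_output")
print(len(instances))
first = instances[0]
print(first.input.text[:200])  # start of the long-context prompt
print([ref.output.text for ref in first.references])  # gold answers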
