# flake8: noqa
# type: ignore
# fmt: off

import json
import random
import re
from typing import Any, List

import numpy as np
from tqdm import tqdm


# The following code is copied verbatim from:
# https://github.com/NVIDIA/RULER/blob/860f2bd5c0430569f5941176f9f97f95e770b3da/scripts/data/synthetic/qa.py
# under the following license:
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Read SQuAD QA dataset
def read_squad(file):
    with open(file) as f:
        data = json.load(f)

    total_docs = [p['context'] for d in data['data'] for p in d['paragraphs']]
    total_docs = sorted(list(set(total_docs)))
    total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}

    total_qas = []
    for d in data['data']:
        more_docs = [total_docs_dict[p['context']] for p in d['paragraphs']]
        for p in d['paragraphs']:
            for qas in p['qas']:
                if not qas['is_impossible']:
                    total_qas.append({
                        'query': qas['question'],
                        'outputs': [a['text'] for a in qas['answers']],
                        'context': [total_docs_dict[p['context']]],
                        'more_context': [idx for idx in more_docs if idx != total_docs_dict[p['context']]]
                    })

    return total_qas, total_docs
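
# Usage sketch (illustrative, not from the original file): read_squad expects a
# SQuAD-v2.0-style JSON, i.e. {"data": [{"paragraphs": [{"context": ...,
# "qas": [{"question": ..., "is_impossible": ..., "answers": [...]}]}]}]};
# the file name below is hypothetical.
#
#   qas, docs = read_squad('squad_dev_v2.0.json')
#   print(qas[0]['query'], docs[qas[0]['context'][0]][:80])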

# Read Hotpot QA dataset
def read_hotpotqa(file):
    with open(file) as f:
        data = json.load(f)

    total_docs = [f"{t}\n{''.join(p)}" for d in data for t, p in d['context']]
    total_docs = sorted(list(set(total_docs)))
    total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}

    total_qas = []
    for d in data:
        total_qas.append({
            'query': d['question'],
            'outputs': [d['answer']],
            'context': [total_docs_dict[f"{t}\n{''.join(p)}"] for t, p in d['context']],
        })

    return total_qas, total_docs
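
# Usage sketch (illustrative): read_hotpotqa expects the HotpotQA distractor
# JSON, a list of {"question", "answer", "context"} records where "context" is
# a list of [title, [sentence, ...]] pairs; the path below is hypothetical.
#
#   qas, docs = read_hotpotqa('hotpot_dev_distractor_v1.json')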


DOCUMENT_PROMPT = "Document {i}:\n{document}"

def generate_input_output(index, num_docs, template: str, random_seed: int, qas: Any, docs: Any):
    curr_q = qas[index]['query']
    curr_a = qas[index]['outputs']
    curr_docs = qas[index]['context']
    curr_more = qas[index].get('more_context', [])
    if num_docs < len(docs):
        if (num_docs - len(curr_docs)) > len(curr_more):
            addition_docs = [i for i, d in enumerate(docs) if i not in curr_docs + curr_more]
            all_docs = curr_docs + curr_more + random.sample(addition_docs, max(0, num_docs - len(curr_docs) - len(curr_more)))
        else:
            all_docs = curr_docs + random.sample(curr_more, num_docs - len(curr_docs))

        all_docs = [docs[idx] for idx in all_docs]
    else:
        all_docs = docs

    random.Random(random_seed).shuffle(all_docs)

    context = '\n\n'.join([DOCUMENT_PROMPT.format(i=i+1, document=d) for i, d in enumerate(all_docs)])
    input_text = template.format(
        context=context,
        query=curr_q
    )
    return input_text, curr_a
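
# Illustrative usage sketch (the template below is an assumption, not the one
# shipped with RULER): the function pads the gold document(s) with distractors
# up to num_docs, shuffles them deterministically, and fills the template.
#
#   template = ("Answer the question based on the given documents.\n\n"
#               "{context}\n\nQuestion: {query} Answer:")
#   text, answers = generate_input_output(0, 20, template=template,
#                                         random_seed=42, qas=qas, docs=docs)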


# The following code has been modified from the original source from:
# https://github.com/NVIDIA/RULER/blob/860f2bd5c0430569f5941176f9f97f95e770b3da/scripts/data/synthetic/qa.py
# under the same Apache 2.0 license included above.


def _text_to_tokens(text: str) -> List[str]:
    # Crude token count via whitespace splitting; these are words, not model
    # tokens, so all lengths below are approximations of true token counts.
    return re.split(r"\s+", text.strip())


def generate_samples(dataset: str, dataset_path: str, template: str, random_seed: int, pre_samples: int, num_samples: int, tokens_to_generate: int, max_seq_length: int, incremental: int = 10, remove_newline_tab: bool = False):
    random.seed(random_seed)
    np.random.seed(random_seed)

    if dataset == 'squad':
        qas, docs = read_squad(dataset_path)
    elif dataset == 'hotpotqa':
        qas, docs = read_hotpotqa(dataset_path)
    else:
        raise NotImplementedError(f'{dataset} is not implemented.')

    write_jsons = []

    # Grow num_docs until the prompt would exceed max_seq_length
    num_docs = incremental

    total_tokens = 0  # Track the total tokens generated for this example
    while total_tokens + tokens_to_generate < max_seq_length:
        input_text, answer = generate_input_output(0, num_docs, template=template, random_seed=random_seed, qas=qas, docs=docs)
        # Calculate the number of tokens in the example
        total_tokens = len(_text_to_tokens(input_text + f' {answer}'))
        print(f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}')
        if total_tokens + tokens_to_generate > max_seq_length:
            num_docs -= incremental
            break

        num_docs += incremental
        if num_docs > len(docs):
            num_docs = len(docs)
            break
    print('Number of documents:', num_docs)

    # Generate samples
    for index in tqdm(range(num_samples)):
        used_docs = num_docs
        while True:
            try:
                input_text, answer = generate_input_output(index + pre_samples, used_docs, template=template, random_seed=random_seed, qas=qas, docs=docs)
                length = len(_text_to_tokens(input_text)) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except AssertionError:
                # Prompt too long: retry with fewer documents.
                if used_docs > incremental:
                    used_docs -= incremental
                else:
                    raise  # cannot shrink further; re-raise instead of looping forever

        if remove_newline_tab:
            input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split())

        formatted_output = {
            "index": index,
            "input": input_text,
            "outputs": answer,
            "length": length
        }
        write_jsons.append(formatted_output)

    return write_jsons
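

# Minimal end-to-end sketch, assuming a local HotpotQA dev file and a simple
# template (both hypothetical; the real driver lives elsewhere in RULER):
if __name__ == "__main__":
    samples = generate_samples(
        dataset='hotpotqa',
        dataset_path='hotpot_dev_distractor_v1.json',  # hypothetical path
        template="Answer based on the documents.\n\n{context}\n\nQuestion: {query} Answer:",
        random_seed=42,
        pre_samples=0,
        num_samples=2,
        tokens_to_generate=32,
        max_seq_length=4096,
    )
    print(samples[0]['length'], samples[0]['outputs'])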