chore: shuffle data on the generated HF splits #74

Merged 1 commit on May 8, 2024
80 changes in tti_eval/dataset/types/hugging_face.py: 46 additions & 34 deletions
```diff
@@ -1,4 +1,5 @@
 from datasets import ClassLabel, DatasetDict, Sequence, Value, load_dataset
+from datasets import Dataset as _RemoteHFDataset
 
 from tti_eval.dataset import Dataset, Split
 
@@ -29,47 +30,58 @@ def set_transform(self, transform):
         super().set_transform(transform)
         self._dataset.set_transform(transform)
 
-    def _get_available_splits(self, **kwargs) -> list[Split]:
-        datasets: DatasetDict = load_dataset(self.title_in_source, cache_dir=self._cache_dir.as_posix(), **kwargs)
-        return list(Split(s) for s in datasets.keys() if s in [_ for _ in Split]) + [Split.ALL]
+    @staticmethod
+    def _get_available_splits(dataset_dict: DatasetDict) -> list[Split]:
+        return [Split(s) for s in dataset_dict.keys() if s in [_ for _ in Split]] + [Split.ALL]
 
-    def _setup(self, **kwargs):
+    def _get_hf_dataset_split(self, **kwargs) -> _RemoteHFDataset:
         try:
-            available_splits = self._get_available_splits(**kwargs)
+            if self.split == Split.ALL:  # Retrieve the whole dataset when the split is ALL
+                return load_dataset(self.title_in_source, split="all", cache_dir=self._cache_dir.as_posix(), **kwargs)
+            dataset_dict: DatasetDict = load_dataset(
+                self.title_in_source,
+                cache_dir=self._cache_dir.as_posix(),
+                **kwargs,
+            )
         except Exception as e:
             raise ValueError(f"Failed to load dataset from Hugging Face: `{self.title_in_source}`") from e
 
+        available_splits = HFDataset._get_available_splits(dataset_dict)
         missing_splits = [s for s in Split if s not in available_splits]
-        if self.split == Split.TRAIN and self.split in missing_splits:
-            # Train split must always exist
-            raise AttributeError(f"Missing train split in Hugging Face dataset `{self.title_in_source}`")
 
-        # Select appropriate HF dataset split
-        hf_split: str
-        if self.split == Split.ALL:
-            hf_split = str(self.split)
-        elif self.split == Split.TEST:
-            hf_split = f"train[{85}%:]" if self.split in missing_splits else str(self.split)
-        elif self.split == Split.VALIDATION:
-            # The range of the validation split in the training data may vary depending on whether
-            # the test split is also missing
-            hf_split = (
-                f"train[{100 - 15 * len(missing_splits)}%:{100 - 15 * (len(missing_splits) - 1)}%]"
-                if self.split in missing_splits
-                else str(self.split)
-            )
-        elif self.split == Split.TRAIN:
-            # Take into account the capacity taken by missing splits
-            hf_split = f"train[:{100 - 15 * len(missing_splits)}%]"
-        else:
-            raise ValueError(f"Unhandled split type `{self.split}`")
-
-        self._dataset = load_dataset(
-            self.title_in_source,
-            split=hf_split,
-            cache_dir=self._cache_dir.as_posix(),
-            **kwargs,
-        )
+        # Return the target dataset directly if it already exists and won't be modified
+        if self.split in [Split.VALIDATION, Split.TEST] and self.split in available_splits:
+            return dataset_dict[self.split]
+        if self.split == Split.TRAIN:
+            if self.split in missing_splits:
+                # The train split must always exist
+                raise AttributeError(f"Missing train split in Hugging Face dataset: `{self.title_in_source}`")
+            if not missing_splits:
+                # No need to split the train dataset, so it can be returned whole
+                return dataset_dict[self.split]
+
+        # Reserve 15% of the train dataset for each missing split (VALIDATION, TEST or both).
+        # This operation shuffles the data to prevent splits with skewed class counts caused by the input order.
+        split_percent = 0.15 * len(missing_splits)
+        split_seed = 42
+        # Split the original train dataset in two: the final train dataset and the missing splits' data
+        split_to_dataset = dataset_dict["train"].train_test_split(test_size=split_percent, seed=split_seed)
+        if self.split == Split.TRAIN:
+            return split_to_dataset[self.split]
+
+        if len(missing_splits) == 1:
+            # One missing split (either VALIDATION or TEST), so return the 15% stored in "test"
+            return split_to_dataset["test"]
+
+        # Both VALIDATION and TEST splits are missing.
+        # Each one takes half of the 30% stored in "test".
+        if self.split == Split.VALIDATION:
+            return split_to_dataset["test"].train_test_split(test_size=0.5, seed=split_seed)["train"]
+        else:
+            return split_to_dataset["test"].train_test_split(test_size=0.5, seed=split_seed)["test"]
+
+    def _setup(self, **kwargs):
+        self._dataset = self._get_hf_dataset_split(**kwargs)
 
         if self._target_feature not in self._dataset.features:
            raise ValueError(f"The dataset `{self.title}` does not have the target feature `{self._target_feature}`")
```
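For context, here is a minimal sketch (not part of the PR) of the split arithmetic the new `_get_hf_dataset_split` performs when a dataset ships with only a `train` split, so both `validation` and `test` must be derived from it. The toy data and variable names are illustrative; `Dataset.from_dict` and `train_test_split` are the actual `datasets` APIs used in the diff.

```python
from datasets import Dataset

# A toy "train" split whose labels arrive sorted: the worst case for the old
# unshuffled percentage slicing (a "train[85%:]" slice here would contain only label 1).
train = Dataset.from_dict({"label": [0] * 70 + [1] * 30})

# Both VALIDATION and TEST are missing, so reserve 2 * 15% = 30% of the train
# data; train_test_split shuffles by default, and the fixed seed keeps the
# partition reproducible across calls.
first = train.train_test_split(test_size=0.30, seed=42)
new_train, held_out = first["train"], first["test"]

# Halve the held-out 30% into validation and test, reusing the same seed.
second = held_out.train_test_split(test_size=0.5, seed=42)
validation, test = second["train"], second["test"]

print(len(new_train), len(validation), len(test))  # 70 15 15
# Unlike the old percentage slices, each derived split now draws from the
# shuffled data rather than from whatever labels happened to come last.
```

Because every call reuses the same `seed=42` (the PR's `split_seed`) and the same `test_size`, separate invocations for TRAIN, VALIDATION, and TEST reproduce the identical partition, so the derived splits stay mutually disjoint.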