Skip to content

Commit 6f93169

Browse files
committed
docstrings
1 parent 37a1d46 commit 6f93169

File tree

3 files changed

+106
-3
lines changed

3 files changed

+106
-3
lines changed

src/lobster/datasets/_multiplexed_sampling_dataset.py

+79-3
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,80 @@
77

88

99
class MultiplexedSamplingDataset(IterableDataset):
10-
"""Dataset that samples from multiple datasets according to specified weights."""
11-
1210
def __init__(
1311
self,
1412
datasets: Sequence[Dataset | IterableDataset],
1513
weights: Sequence[float | int] = None,
1614
seed: int | None = None,
1715
):
16+
"""Dataset that samples from multiple datasets according to specified weights.
17+
18+
This dataset implements a weighted sampling strategy across multiple source datasets.
19+
For each iteration, it randomly selects a source dataset according to the provided
20+
weights and yields the next item from that dataset. This allows creating custom
21+
mixing ratios of different data sources without having to physically combine them.
22+
23+
Parameters
24+
----------
25+
datasets : Sequence[Dataset | IterableDataset]
26+
A sequence of datasets to sample from. These can be either map-style
27+
datasets (implementing __getitem__ and __len__) or iterable-style
28+
datasets (implementing __iter__).
29+
30+
weights : Sequence[float | int], optional
31+
Relative sampling weights for each dataset. Can be > 1.0.
32+
If None, equal weights will be assigned to all datasets.
33+
Must have the same length as datasets.
34+
Weights will be normalized internally so they sum to 1.0.
35+
Non-positive weights are not allowed.
36+
37+
seed : int or None, optional
38+
Random seed for reproducible sampling. If None, sampling will not be
39+
reproducible across runs.
40+
41+
Raises
42+
------
43+
ValueError
44+
If the number of weights doesn't match the number of datasets,
45+
or if any weight is non-positive.
46+
47+
Notes
48+
-----
49+
- If any dataset is exhausted during iteration, the entire iteration will stop.
50+
- When using this dataset with multiple workers, each worker will sample
51+
independently with the same weights but potentially different items.
52+
- Setting a seed ensures reproducible sampling sequences.
53+
54+
55+
Examples
56+
--------
57+
from torch.utils.data import IterableDataset
58+
# Create three simple iterable datasets
59+
datasets = [
60+
IterableStringDataset(["Banana"] * 100),
61+
IterableStringDataset(["Apple"] * 500),
62+
IterableStringDataset(["Orange"] * 1000)
63+
]
64+
65+
# Equal weighting (default)
66+
equal_dataset = MultiplexedSamplingDataset(datasets, seed=42)
67+
samples = [next(iter(equal_dataset)) for _ in range(6)]
68+
# Output would be a mix of fruits with roughly equal probability
69+
# Note that it **doesn't** take the number of items in each dataset into account
70+
# ['Banana', 'Orange', 'Apple', 'Orange', 'Banana', 'Apple']
71+
72+
# Custom weighting (99% bananas)
73+
banana_heavy = MultiplexedSamplingDataset(
74+
datasets,
75+
weights=[0.99, 0.005, 0.005],
76+
seed=42
77+
)
78+
samples = [next(iter(banana_heavy)) for _ in range(6)]
79+
# Output would be mostly bananas
80+
# ['Banana', 'Banana', 'Banana', 'Banana', 'Banana', 'Banana']
81+
82+
83+
"""
1884
if weights is not None:
1985
if len(datasets) != len(weights):
2086
raise ValueError("Number of datasets and weights must match")
@@ -37,8 +103,18 @@ def __init__(
37103
self.generator = None
38104

39105
def __iter__(self):
40-
"""Iterate over samples from datasets according to weights."""
106+
"""Iterate over samples from datasets according to weights.
107+
108+
Yields
109+
------
110+
Any
111+
Items sampled from the constituent datasets according to the specified weights.
41112
113+
Notes
114+
-----
115+
The iteration stops when any of the constituent datasets is exhausted,
116+
even if other datasets still have items available.
117+
"""
42118
# Create iterators for each dataset
43119
# Assume each dataset handles worker sharding internally
44120
iterators = {dataset: iter(dataset) for dataset in self.datasets}

src/lobster/datasets/_shuffled_iterable_dataset.py

+27
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,33 @@ def __init__(
1010
buffer_size: int = 10000,
1111
seed: int | None = None,
1212
):
13+
"""
14+
A dataset wrapper that applies shuffling to an iterable dataset using a buffer.
15+
16+
This implementation maintains a buffer of items from the underlying dataset
17+
and yields a random item from this buffer each time, replacing it with a new
18+
item from the dataset. This provides approximate shuffling for iterable datasets
19+
that cannot be fully loaded into memory.
20+
21+
Parameters
22+
----------
23+
dataset : IterableDataset
24+
The underlying dataset to shuffle.
25+
buffer_size : int, optional
26+
The size of the buffer used for shuffling, by default 10000.
27+
Larger buffer sizes provide better shuffling at the cost of memory.
28+
seed : int or None, optional
29+
Random seed for reproducibility, by default None.
30+
If None, a random seed will be generated.
31+
32+
Notes
33+
-----
34+
The shuffling is approximate and depends on the buffer size. A larger buffer
35+
provides better shuffling but requires more memory.
36+
37+
This implementation also handles distributed data loading with multiple workers
38+
by ensuring each worker uses a different random seed derived from a shared base seed.
39+
"""
1340
super().__init__()
1441

1542
self.dataset = dataset

0 commit comments

Comments
 (0)