7
7
8
8
9
9
class MultiplexedSamplingDataset (IterableDataset ):
10
- """Dataset that samples from multiple datasets according to specified weights."""
11
-
12
10
def __init__ (
13
11
self ,
14
12
datasets : Sequence [Dataset | IterableDataset ],
15
13
weights : Sequence [float | int ] = None ,
16
14
seed : int | None = None ,
17
15
):
16
+ """Dataset that samples from multiple datasets according to specified weights.
17
+
18
+ This dataset implements a weighted sampling strategy across multiple source datasets.
19
+ For each iteration, it randomly selects a source dataset according to the provided
20
+ weights and yields the next item from that dataset. This allows creating custom
21
+ mixing ratios of different data sources without having to physically combine them.
22
+
23
+ Parameters
24
+ ----------
25
+ datasets : Sequence[Dataset | IterableDataset]
26
+ A sequence of datasets to sample from. These can be either map-style
27
+ datasets (implementing __getitem__ and __len__) or iterable-style
28
+ datasets (implementing __iter__).
29
+
30
+ weights : Sequence[float | int], optional
31
+ Relative sampling weights for each dataset. Can be > 1.0.
32
+ If None, equal weights will be assigned to all datasets.
33
+ Must have the same length as datasets.
34
+ Weights will be normalized internally so they sum to 1.0.
35
+ Non-positive weights are not allowed.
36
+
37
+ seed : int or None, optional
38
+ Random seed for reproducible sampling. If None, sampling will not be
39
+ reproducible across runs.
40
+
41
+ Raises
42
+ ------
43
+ ValueError
44
+ If the number of weights doesn't match the number of datasets,
45
+ or if any weight is negative.
46
+
47
+ Notes
48
+ -----
49
+ - If any dataset is exhausted during iteration, the entire iteration will stop.
50
+ - When using this dataset with multiple workers, each worker will sample
51
+ independently with the same weights but potentially different items.
52
+ - Setting a seed ensures reproducible sampling sequences.
53
+
54
+
55
+ Examples
56
+ --------
57
+ from torch.utils.data import IterableDataset
58
+ # Create three simple iterable datasets
59
+ datasets = [
60
+ IterableStringDataset(["Banana"] * 100)
61
+ IterableStringDataset(["Apple"] * 500)
62
+ IterableStringDataset(["Orange"] * 1000)
63
+ ]
64
+
65
+ # Equal weighting (default)
66
+ equal_dataset = MultiplexedSamplingDataset(datasets, seed=42)
67
+ samples = [next(iter(equal_dataset)) for _ in range(6)]
68
+ # Output would be a mix of fruits with roughly equal probability
69
+ # Note that it **doesn't** take the number of items in each dataset into account
70
+ # ['Banana', 'Orange', 'Apple', 'Orange', 'Banana','Apple']
71
+
72
+ # Custom weighting (99% bananas)
73
+ banana_heavy = MultiplexedSamplingDataset(
74
+ datasets,
75
+ weights=[0.99, 0.005, 0.005],
76
+ seed=42
77
+ )
78
+ samples = [next(iter(banana_heavy)) for _ in range(6)]
79
+ # Output would be mostly bananas
80
+ # ['Banana', 'Banana', 'Banana', 'Banana', 'Banana', 'Banana',]
81
+
82
+
83
+ """
18
84
if weights is not None :
19
85
if len (datasets ) != len (weights ):
20
86
raise ValueError ("Number of datasets and weights must match" )
@@ -37,8 +103,18 @@ def __init__(
37
103
self .generator = None
38
104
39
105
def __iter__ (self ):
40
- """Iterate over samples from datasets according to weights."""
106
+ """Iterate over samples from datasets according to weights.
107
+
108
+ Yields
109
+ ------
110
+ Any
111
+ Items sampled from the constituent datasets according to the specified weights.
41
112
113
+ Notes
114
+ -----
115
+ The iteration stops when any of the constituent datasets is exhausted,
116
+ even if other datasets still have items available.
117
+ """
42
118
# Create iterators for each dataset
43
119
# Assume each dataset handles worker sharding internally
44
120
iterators = {dataset : iter (dataset ) for dataset in self .datasets }
0 commit comments