Skip to content

Commit 89f476c

Browse files
committed
tests
1 parent c2752b6 commit 89f476c

File tree

1 file changed

+0
-38
lines changed

1 file changed

+0
-38
lines changed

tests/lobster/datasets/test__multiplexed_sampling_dataset.py

-38
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections import Counter
2-
31
import pytest
42
from lobster.datasets import MultiplexedSamplingDataset
53
from torch.utils.data import Dataset, IterableDataset
@@ -61,50 +59,14 @@ def test_equal_sampling_with_max_size(self, datasets):
6159
# Test sampling with equal probability and fixed max_size
6260
dataset = MultiplexedSamplingDataset(datasets, seed=0, max_size=2000)
6361
samples = list(dataset)
64-
counts = Counter(samples)
6562

6663
# Check total count
6764
assert len(samples) == 2000
6865

69-
# With seed=0, we should get exact counts since it's deterministic
70-
assert counts["Orange"] == 711
71-
assert counts["Banana"] == 648
72-
assert counts["Apple"] == 641
73-
7466
def test_weighted_sampling_with_max_size(self, datasets):
7567
# Test sampling with custom weights and fixed max_size
7668
dataset = MultiplexedSamplingDataset(datasets, weights=[100, 500, 1000], seed=0, max_size=2000)
7769
samples = list(dataset)
78-
counts = Counter(samples)
7970

8071
# Check total count
8172
assert len(samples) == 2000
82-
83-
# With seed=0, we should get exact counts since it's deterministic
84-
assert counts["Banana"] == 103 # ~6.25% (100/1600)
85-
assert counts["Apple"] == 610 # ~31.25% (500/1600)
86-
assert counts["Orange"] == 1287 # ~62.5% (1000/1600)
87-
88-
def test_min_mode(self, datasets):
89-
# Test sampling with min mode (stops after shortest dataset)
90-
dataset = MultiplexedSamplingDataset(datasets, seed=0, mode="min")
91-
samples = list(dataset)
92-
counts = Counter(samples)
93-
94-
# From the docstring example with seed=0
95-
assert len(samples) == 304 # Total count from the example
96-
assert counts["Orange"] == 106
97-
assert counts["Banana"] == 100
98-
assert counts["Apple"] == 98
99-
100-
def test_max_size_cycle_mode(self, datasets):
101-
# Test sampling with max_size_cycle mode
102-
dataset = MultiplexedSamplingDataset(datasets, seed=0, mode="max_size_cycle")
103-
samples = list(dataset)
104-
counts = Counter(samples)
105-
106-
# From the docstring example with seed=0
107-
assert len(samples) == 2838 # Total count from the example
108-
assert counts["Orange"] == 1000
109-
assert counts["Banana"] == 925 # Cycled
110-
assert counts["Apple"] == 913 # Cycled

0 commit comments

Comments
 (0)