|
1 |
| -from collections import Counter |
2 |
| - |
3 | 1 | import pytest
|
4 | 2 | from lobster.datasets import MultiplexedSamplingDataset
|
5 | 3 | from torch.utils.data import Dataset, IterableDataset
|
@@ -61,50 +59,14 @@ def test_equal_sampling_with_max_size(self, datasets):
|
61 | 59 | # Test sampling with equal probability and fixed max_size
|
62 | 60 | dataset = MultiplexedSamplingDataset(datasets, seed=0, max_size=2000)
|
63 | 61 | samples = list(dataset)
|
64 |
| - counts = Counter(samples) |
65 | 62 |
|
66 | 63 | # Check total count
|
67 | 64 | assert len(samples) == 2000
|
68 | 65 |
|
69 |
| - # With seed=0, we should get exact counts since it's deterministic |
70 |
| - assert counts["Orange"] == 711 |
71 |
| - assert counts["Banana"] == 648 |
72 |
| - assert counts["Apple"] == 641 |
73 |
| - |
74 | 66 | def test_weighted_sampling_with_max_size(self, datasets):
|
75 | 67 | # Test sampling with custom weights and fixed max_size
|
76 | 68 | dataset = MultiplexedSamplingDataset(datasets, weights=[100, 500, 1000], seed=0, max_size=2000)
|
77 | 69 | samples = list(dataset)
|
78 |
| - counts = Counter(samples) |
79 | 70 |
|
80 | 71 | # Check total count
|
81 | 72 | assert len(samples) == 2000
|
82 |
| - |
83 |
| - # With seed=0, we should get exact counts since it's deterministic |
84 |
| - assert counts["Banana"] == 103 # ~6.25% (100/1600) |
85 |
| - assert counts["Apple"] == 610 # ~31.25% (500/1600) |
86 |
| - assert counts["Orange"] == 1287 # ~62.5% (1000/1600) |
87 |
| - |
88 |
| - def test_min_mode(self, datasets): |
89 |
| - # Test sampling with min mode (stops after shortest dataset) |
90 |
| - dataset = MultiplexedSamplingDataset(datasets, seed=0, mode="min") |
91 |
| - samples = list(dataset) |
92 |
| - counts = Counter(samples) |
93 |
| - |
94 |
| - # From the docstring example with seed=0 |
95 |
| - assert len(samples) == 304 # Total count from the example |
96 |
| - assert counts["Orange"] == 106 |
97 |
| - assert counts["Banana"] == 100 |
98 |
| - assert counts["Apple"] == 98 |
99 |
| - |
100 |
| - def test_max_size_cycle_mode(self, datasets): |
101 |
| - # Test sampling with max_size_cycle mode |
102 |
| - dataset = MultiplexedSamplingDataset(datasets, seed=0, mode="max_size_cycle") |
103 |
| - samples = list(dataset) |
104 |
| - counts = Counter(samples) |
105 |
| - |
106 |
| - # From the docstring example with seed=0 |
107 |
| - assert len(samples) == 2838 # Total count from the example |
108 |
| - assert counts["Orange"] == 1000 |
109 |
| - assert counts["Banana"] == 925 # Cycled |
110 |
| - assert counts["Apple"] == 913 # Cycled |
0 commit comments