Skip to content

Commit

Permalink
linting is done ~
Browse files Browse the repository at this point in the history
  • Loading branch information
keunwoochoi committed Mar 5, 2025
1 parent 003b56e commit bd4ea88
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
66 changes: 33 additions & 33 deletions test/nodes/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase
from torchdata.nodes import Batcher, Filter
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes import Filter, Batcher

from .utils import MockSource, run_test_save_load_state, StatefulRangeNode

Expand All @@ -13,21 +13,21 @@ def test_filter_basic(self) -> None:
# Test with a simple range
source = IterableWrapper(range(10))
node = Filter(source, lambda x: x % 2 == 0) # Keep even numbers

results = list(node)
self.assertEqual(results, [0, 2, 4, 6, 8])

# Verify counters
self.assertEqual(node._num_yielded, 5) # 5 even numbers were yielded
self.assertEqual(node._num_filtered, 5) # 5 odd numbers were filtered out

# Test with a different predicate
source = IterableWrapper(range(10))
node = Filter(source, lambda x: x > 5) # Keep numbers greater than 5

results = list(node)
self.assertEqual(results, [6, 7, 8, 9])

# Verify counters
self.assertEqual(node._num_yielded, 4) # 4 numbers > 5 were yielded
self.assertEqual(node._num_filtered, 6) # 6 numbers <= 5 were filtered out
Expand All @@ -36,18 +36,18 @@ def test_filter_with_mock_source(self) -> None:
num_samples = 20
source = MockSource(num_samples=num_samples)
node = Filter(source, lambda x: x["step"] % 3 == 0) # Keep items where step is divisible by 3

# Test multi epoch
for epoch in range(2):
node.reset()
results = list(node)
expected_steps = [i for i in range(num_samples) if i % 3 == 0]
self.assertEqual(len(results), len(expected_steps))

# Verify counters after each epoch
self.assertEqual(node._num_yielded, len(expected_steps))
self.assertEqual(node._num_filtered, num_samples - len(expected_steps))

for i, result in enumerate(results):
expected_step = expected_steps[i]
self.assertEqual(result["step"], expected_step)
Expand All @@ -57,10 +57,10 @@ def test_filter_with_mock_source(self) -> None:
def test_filter_empty_result(self) -> None:
source = IterableWrapper(range(10))
node = Filter(source, lambda x: x > 100) # No items will pass this filter

results = list(node)
self.assertEqual(results, [])

# Verify counters when no items pass the filter
self.assertEqual(node._num_yielded, 0) # No items were yielded
self.assertEqual(node._num_filtered, 10) # All 10 items were filtered out
Expand All @@ -69,98 +69,98 @@ def test_filter_empty_result(self) -> None:
def test_save_load_state(self, midpoint: int):
n = 50
source = StatefulRangeNode(n=n)
node = Filter(source, lambda x: x['i'] % 3 == 0) # Keep items where 'i' is divisible by 3
node = Filter(source, lambda x: x["i"] % 3 == 0) # Keep items where 'i' is divisible by 3
run_test_save_load_state(self, node, midpoint)

def test_filter_reset_state(self) -> None:
source = IterableWrapper(range(10))
node = Filter(source, lambda x: x % 2 == 0)

# Consume first two items
self.assertEqual(next(node), 0)
self.assertEqual(next(node), 2)

# Check counters after consuming two items
self.assertEqual(node._num_yielded, 2) # 2 even numbers were yielded
self.assertEqual(node._num_filtered, 1) # 1 odd number was filtered out

# Get state and reset
state = node.state_dict()
node.reset(state)

# Counters should be preserved after reset with state
self.assertEqual(node._num_yielded, 2)
self.assertEqual(node._num_filtered, 1)

# Should continue from where we left off
self.assertEqual(next(node), 4)
self.assertEqual(next(node), 6)
self.assertEqual(next(node), 8)

# Counters should be updated after consuming more items
self.assertEqual(node._num_yielded, 5) # Total of 5 even numbers yielded
self.assertEqual(node._num_filtered, 4) # Total of 4 odd numbers filtered out

# Should raise StopIteration after all items are consumed
with self.assertRaises(StopIteration):
next(node)

def test_filter_with_batcher(self) -> None:
# Test Filter node with Batcher

# Create a source with numbers 0-19
source = IterableWrapper(range(20))

# Batch into groups of 4
batch_node = Batcher(source, batch_size=4)

# Filter to keep only batches where the sum is divisible by 10
filter_node = Filter(batch_node, lambda batch: sum(batch) % 10 == 0)

# Let's calculate the expected batches and their sums
# Batch 1: [0, 1, 2, 3] -> sum = 6
# Batch 2: [4, 5, 6, 7] -> sum = 22
# Batch 3: [8, 9, 10, 11] -> sum = 38
# Batch 4: [12, 13, 14, 15] -> sum = 54
# Batch 5: [16, 17, 18, 19] -> sum = 70
# Batches with sum divisible by 10: Batch 5 (70)

results = list(filter_node)

# We expect only one batch to pass the filter (sum divisible by 10)
self.assertEqual(len(results), 1)
self.assertEqual(results[0], [16, 17, 18, 19]) # sum = 70

# Check that the filter node tracked both filtered and yielded items
self.assertEqual(filter_node._num_yielded, 1) # 1 batch was yielded
self.assertEqual(filter_node._num_filtered, 4) # 4 batches were filtered out

# Verify total number of batches processed
self.assertEqual(filter_node._num_yielded + filter_node._num_filtered, 5) # Total of 5 batches

def test_counter_reset(self) -> None:
# Test that counters are properly reset
source = IterableWrapper(range(10))
node = Filter(source, lambda x: x % 2 == 0)

# Consume all items
list(node)

# Verify counters after first pass
self.assertEqual(node._num_yielded, 5)
self.assertEqual(node._num_filtered, 5)

# Reset without state
node.reset()

# Counters should be reset to 0
self.assertEqual(node._num_yielded, 0)
self.assertEqual(node._num_filtered, 0)

# Consume some items
next(node) # 0
next(node) # 2

# Verify counters after partial consumption
self.assertEqual(node._num_yielded, 2)
self.assertEqual(node._num_filtered, 1)
7 changes: 4 additions & 3 deletions torchdata/nodes/filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, TypeVar, Optional
from typing import Any, Callable, Dict, Optional, TypeVar

from torchdata.nodes import BaseNode


Expand Down Expand Up @@ -28,7 +29,7 @@ def __init__(self, source_node: BaseNode[T], filter_fn: Callable[[T], bool]):
self.source = source_node
self.filter_fn = filter_fn
self._num_filtered = 0 # Count of items that did NOT pass the filter
self._num_yielded = 0 # Count of items that DID pass the filter
self._num_yielded = 0 # Count of items that DID pass the filter

def reset(self, initial_state: Optional[Dict[str, Any]] = None):
"""Reset the node to its initial state or to the provided state.
Expand Down Expand Up @@ -71,5 +72,5 @@ def get_state(self) -> Dict[str, Any]:
return {
self.SOURCE_KEY: self.source.state_dict(),
self.NUM_FILTERED_KEY: self._num_filtered,
self.NUM_YIELDED_KEY: self._num_yielded
self.NUM_YIELDED_KEY: self._num_yielded,
}

0 comments on commit bd4ea88

Please sign in to comment.