-
Notifications
You must be signed in to change notification settings - Fork 160
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
2025-03-06 nightly release (327e225)
- Loading branch information
pytorchbot
committed
Mar 6, 2025
1 parent
ff4396e
commit 4a6f5ca
Showing
3 changed files
with
243 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import itertools | ||
|
||
from parameterized import parameterized | ||
from torch.testing._internal.common_utils import TestCase | ||
from torchdata.nodes import Batcher, Filter | ||
from torchdata.nodes.adapters import IterableWrapper | ||
|
||
from .utils import MockSource, run_test_save_load_state, StatefulRangeNode | ||
|
||
|
||
class TestFilter(TestCase): | ||
def test_filter_basic(self) -> None: | ||
# Test with a simple range | ||
source = IterableWrapper(range(10)) | ||
node = Filter(source, lambda x: x % 2 == 0) # Keep even numbers | ||
|
||
results = list(node) | ||
self.assertEqual(results, [0, 2, 4, 6, 8]) | ||
|
||
# Verify counters | ||
self.assertEqual(node._num_yielded, 5) # 5 even numbers were yielded | ||
self.assertEqual(node._num_filtered, 5) # 5 odd numbers were filtered out | ||
|
||
# Test with a different predicate | ||
source = IterableWrapper(range(10)) | ||
node = Filter(source, lambda x: x > 5) # Keep numbers greater than 5 | ||
|
||
results = list(node) | ||
self.assertEqual(results, [6, 7, 8, 9]) | ||
|
||
# Verify counters | ||
self.assertEqual(node._num_yielded, 4) # 4 numbers > 5 were yielded | ||
self.assertEqual(node._num_filtered, 6) # 6 numbers <= 5 were filtered out | ||
|
||
def test_filter_with_mock_source(self) -> None: | ||
num_samples = 20 | ||
source = MockSource(num_samples=num_samples) | ||
node = Filter(source, lambda x: x["step"] % 3 == 0) # Keep items where step is divisible by 3 | ||
|
||
# Test multi epoch | ||
for epoch in range(2): | ||
node.reset() | ||
results = list(node) | ||
expected_steps = [i for i in range(num_samples) if i % 3 == 0] | ||
self.assertEqual(len(results), len(expected_steps)) | ||
|
||
# Verify counters after each epoch | ||
self.assertEqual(node._num_yielded, len(expected_steps)) | ||
self.assertEqual(node._num_filtered, num_samples - len(expected_steps)) | ||
|
||
for i, result in enumerate(results): | ||
expected_step = expected_steps[i] | ||
self.assertEqual(result["step"], expected_step) | ||
self.assertEqual(result["test_tensor"].item(), expected_step) | ||
self.assertEqual(result["test_str"], f"str_{expected_step}") | ||
|
||
def test_filter_empty_result(self) -> None: | ||
source = IterableWrapper(range(10)) | ||
node = Filter(source, lambda x: x > 100) # No items will pass this filter | ||
|
||
results = list(node) | ||
self.assertEqual(results, []) | ||
|
||
# Verify counters when no items pass the filter | ||
self.assertEqual(node._num_yielded, 0) # No items were yielded | ||
self.assertEqual(node._num_filtered, 10) # All 10 items were filtered out | ||
|
||
@parameterized.expand(itertools.product([0, 3, 7])) | ||
def test_save_load_state(self, midpoint: int): | ||
n = 50 | ||
source = StatefulRangeNode(n=n) | ||
node = Filter(source, lambda x: x["i"] % 3 == 0) # Keep items where 'i' is divisible by 3 | ||
run_test_save_load_state(self, node, midpoint) | ||
|
||
def test_filter_reset_state(self) -> None: | ||
source = IterableWrapper(range(10)) | ||
node = Filter(source, lambda x: x % 2 == 0) | ||
|
||
# Consume first two items | ||
self.assertEqual(next(node), 0) | ||
self.assertEqual(next(node), 2) | ||
|
||
# Check counters after consuming two items | ||
self.assertEqual(node._num_yielded, 2) # 2 even numbers were yielded | ||
self.assertEqual(node._num_filtered, 1) # 1 odd number was filtered out | ||
|
||
# Get state and reset | ||
state = node.state_dict() | ||
node.reset(state) | ||
|
||
# Counters should be preserved after reset with state | ||
self.assertEqual(node._num_yielded, 2) | ||
self.assertEqual(node._num_filtered, 1) | ||
|
||
# Should continue from where we left off | ||
self.assertEqual(next(node), 4) | ||
self.assertEqual(next(node), 6) | ||
self.assertEqual(next(node), 8) | ||
|
||
# Counters should be updated after consuming more items | ||
self.assertEqual(node._num_yielded, 5) # Total of 5 even numbers yielded | ||
self.assertEqual(node._num_filtered, 4) # Total of 4 odd numbers filtered out | ||
|
||
# Should raise StopIteration after all items are consumed | ||
with self.assertRaises(StopIteration): | ||
next(node) | ||
|
||
def test_filter_with_batcher(self) -> None: | ||
# Test Filter node with Batcher | ||
|
||
# Create a source with numbers 0-19 | ||
source = IterableWrapper(range(20)) | ||
|
||
# Batch into groups of 4 | ||
batch_node = Batcher(source, batch_size=4) | ||
|
||
# Filter to keep only batches where the sum is divisible by 10 | ||
filter_node = Filter(batch_node, lambda batch: sum(batch) % 10 == 0) | ||
|
||
# Let's calculate the expected batches and their sums | ||
# Batch 1: [0, 1, 2, 3] -> sum = 6 | ||
# Batch 2: [4, 5, 6, 7] -> sum = 22 | ||
# Batch 3: [8, 9, 10, 11] -> sum = 38 | ||
# Batch 4: [12, 13, 14, 15] -> sum = 54 | ||
# Batch 5: [16, 17, 18, 19] -> sum = 70 | ||
# Batches with sum divisible by 10: Batch 5 (70) | ||
|
||
results = list(filter_node) | ||
|
||
# We expect only one batch to pass the filter (sum divisible by 10) | ||
self.assertEqual(len(results), 1) | ||
self.assertEqual(results[0], [16, 17, 18, 19]) # sum = 70 | ||
|
||
# Check that the filter node tracked both filtered and yielded items | ||
self.assertEqual(filter_node._num_yielded, 1) # 1 batch was yielded | ||
self.assertEqual(filter_node._num_filtered, 4) # 4 batches were filtered out | ||
|
||
# Verify total number of batches processed | ||
self.assertEqual(filter_node._num_yielded + filter_node._num_filtered, 5) # Total of 5 batches | ||
|
||
def test_counter_reset(self) -> None: | ||
# Test that counters are properly reset | ||
source = IterableWrapper(range(10)) | ||
node = Filter(source, lambda x: x % 2 == 0) | ||
|
||
# Consume all items | ||
list(node) | ||
|
||
# Verify counters after first pass | ||
self.assertEqual(node._num_yielded, 5) | ||
self.assertEqual(node._num_filtered, 5) | ||
|
||
# Reset without state | ||
node.reset() | ||
|
||
# Counters should be reset to 0 | ||
self.assertEqual(node._num_yielded, 0) | ||
self.assertEqual(node._num_filtered, 0) | ||
|
||
# Consume some items | ||
next(node) # 0 | ||
next(node) # 2 | ||
|
||
# Verify counters after partial consumption | ||
self.assertEqual(node._num_yielded, 2) | ||
self.assertEqual(node._num_filtered, 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import Any, Callable, Dict, Optional, TypeVar | ||
|
||
from torchdata.nodes import BaseNode | ||
|
||
|
||
T = TypeVar("T") | ||
|
||
|
||
class Filter(BaseNode[T]): | ||
"""Node that filters items from source node based on predicate function. | ||
This node applies a filter function to each item from the source node and only yields | ||
items that satisfy the condition (when filter_fn returns True). It keeps track of both | ||
the number of items that were filtered out (rejected) and the number of items that were | ||
yielded (accepted). | ||
Args: | ||
source_node (BaseNode[T]): The source node to filter items from. | ||
filter_fn (Callable[[T], bool]): A function that takes an item and returns True if the item | ||
should be included, False otherwise. | ||
""" | ||
|
||
SOURCE_KEY = "source" | ||
NUM_FILTERED_KEY = "num_filtered" | ||
NUM_YIELDED_KEY = "num_yielded" | ||
|
||
def __init__(self, source_node: BaseNode[T], filter_fn: Callable[[T], bool]): | ||
super().__init__() | ||
self.source = source_node | ||
self.filter_fn = filter_fn | ||
self._num_filtered = 0 # Count of items that did NOT pass the filter | ||
self._num_yielded = 0 # Count of items that DID pass the filter | ||
|
||
def reset(self, initial_state: Optional[Dict[str, Any]] = None): | ||
"""Reset the node to its initial state or to the provided state. | ||
Args: | ||
initial_state: Optional state dictionary to restore from. | ||
""" | ||
super().reset(initial_state) | ||
if initial_state is not None: | ||
self.source.reset(initial_state.get(self.SOURCE_KEY)) | ||
self._num_filtered = initial_state.get(self.NUM_FILTERED_KEY, 0) | ||
self._num_yielded = initial_state.get(self.NUM_YIELDED_KEY, 0) | ||
else: | ||
self.source.reset(None) | ||
self._num_filtered = 0 | ||
self._num_yielded = 0 | ||
|
||
def next(self) -> T: | ||
"""Get the next item that passes the filter. | ||
Returns: | ||
The next item that satisfies the filter condition. | ||
Raises: | ||
StopIteration: If there are no more items in the source node. | ||
""" | ||
while True: | ||
item = next(self.source) | ||
if self.filter_fn(item): | ||
self._num_yielded += 1 | ||
return item | ||
self._num_filtered += 1 | ||
|
||
def get_state(self) -> Dict[str, Any]: | ||
"""Get the current state of the node. | ||
Returns: | ||
A dictionary containing the state of the source node and counters. | ||
""" | ||
return { | ||
self.SOURCE_KEY: self.source.state_dict(), | ||
self.NUM_FILTERED_KEY: self._num_filtered, | ||
self.NUM_YIELDED_KEY: self._num_yielded, | ||
} |