Skip to content

Commit

Permalink
add lazy IteratorVariable implementations for map and zip (#131413)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch/pytorch#130750.

Repro of lazy/eager `map` discrepancy without `islice`:
```python
    def fn(a, b):
        y = 1

        def f(x):
            nonlocal y
            y += 1
            return x

        l = list(zip([a, b], map(f, [1, 2, 3, 4])))
        return a + y
```

The major change is that we implement `MapVariable` and `ZipVariable` based on `IteratorVariable`. Before, `map` and `zip` were being traced by immediately unpacking the result as a `TupleVariable`, which is wrong in cases such as the example above.

`MapVariable`s are not allowed to be unpacked while `ZipVariable`s can only be unpacked if all of its iterables can also be unpacked.

We also add new `[has_]force_unpack_var_sequence` methods to `VariableTracker` for the case where it is safe to unpack the entire sequence lazily, e.g., when building a list from a map (i.e. `list(map(f, ...))`).

X-link: pytorch/pytorch#131413
Approved by: https://github.com/anijain2305

Reviewed By: clee2000

Differential Revision: D60322948

Pulled By: williamwen42

fbshipit-source-id: 52f7763d58943696c0ed2abf8fa03fa6795d1be9
  • Loading branch information
williamwen42 authored and facebook-github-bot committed Jul 27, 2024
1 parent dfb15d9 commit 3f8fa49
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,8 +1413,12 @@ def same(
"""Check correctness to see if ref and res match"""
if fp64_ref is None:
fp64_ref = ref
if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}"
if isinstance(
ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size)
):
assert isinstance(
res, (list, tuple, collections.deque)
), f"type mismatch {type(ref)} {type(res)}"
if len(ref) != len(res):
log_error("Length mismatch")
return False
Expand Down

0 comments on commit 3f8fa49

Please sign in to comment.