Skip to content

Commit 7325bc9

Browse files
authored
fix: Context (tinygrad#1430)
* Fixed issue in Context * Cleaned up fix Now that DEBUG.value = 3 always works we can do so in __new__ as well.
1 parent c08ed19 commit 7325bc9

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

test/test_helpers.py

+7
Original file line numberDiff line numberDiff line change
@@ -99,5 +99,12 @@ def test():
9999
test()
100100
self.assertEqual(VARIABLE.value, 0)
101101

102+
def test_context_exit_reverts_updated_values(self):
103+
D = ContextVar("D", 1)
104+
D.value = 2
105+
with Context(D=3):
106+
...
107+
assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."
108+
102109
if __name__ == '__main__':
103110
unittest.main()

tinygrad/helpers.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ class Context(contextlib.ContextDecorator):
3030
stack: ClassVar[List[dict[str, int]]] = [{}]
3131
def __init__(self, **kwargs): self.kwargs = kwargs
3232
def __enter__(self):
33-
for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
34-
Context.stack.append(self.kwargs)
33+
Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state.
34+
for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state.
35+
Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
3536
def __exit__(self, *args):
36-
for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, Context.stack[0][k])
37+
for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
3738

3839
class ContextVar:
3940
_cache: ClassVar[Dict[str, ContextVar]] = {}
@@ -42,7 +43,7 @@ class ContextVar:
4243
def __new__(cls, key, default_value):
4344
if key in ContextVar._cache: return ContextVar._cache[key]
4445
instance = ContextVar._cache[key] = super().__new__(cls)
45-
instance.value = Context.stack[0][key] = getenv(key, default_value)
46+
instance.value = getenv(key, default_value)
4647
return instance
4748
def __bool__(self): return bool(self.value)
4849
def __ge__(self, x): return self.value >= x

0 commit comments

Comments
 (0)