You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
FlopCounterMode: Decompose ops for inference mode (pytorch#138508)
Fixespytorch#126268
I've basically followed @ezyang suggestion (I think) to use `func.decompose(...)`. Since `__torch_dispatch__` won't be called a second time for the same op, I've added a second `TorchDispatchMode` (`_DecomposedCounterMode`) that simpy dispatches to the parent flop counter. Using `self` as the inner context manager is not possible, since the second call to `__enter__` would re-initialize the counter's tracking state.
Let me know if there's something wrong with this implementation, since I'm quite unsure how the decomposition thing actually works :D
Pull Request resolved: pytorch#138508
Approved by: https://github.com/ezyang
Co-authored-by: Edward Z. Yang <ezyang@meta.com>
0 commit comments