
Commit 4fa7216

Feuermagier authored and ezyang committed
FlopCounterMode: Decompose ops for inference mode (pytorch#138508)
Fixes pytorch#126268

I've basically followed @ezyang's suggestion (I think) to use `func.decompose(...)`. Since `__torch_dispatch__` won't be called a second time for the same op, I've added a second `TorchDispatchMode` (`_DecomposedCounterMode`) that simply dispatches to the parent flop counter. Using `self` as the inner context manager is not possible, since the second call to `__enter__` would re-initialize the counter's tracking state.

Let me know if there's something wrong with this implementation, since I'm quite unsure how the decomposition thing actually works :D

Pull Request resolved: pytorch#138508
Approved by: https://github.com/ezyang
Co-authored-by: Edward Z. Yang <ezyang@meta.com>
1 parent cffeb83 commit 4fa7216
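
A minimal usage sketch of what this change enables (model and input names are illustrative; assumes torchvision is installed): with this patch, FLOP counts collected under torch.inference_mode() match those collected in ordinary eager mode, as the new test below verifies.

    import torch
    import torchvision.models as models
    from torch.utils.flop_counter import FlopCounterMode

    model = models.resnet18()
    x = torch.randn(1, 3, 224, 224)

    # Before this change, ops seen under inference mode could miss the flop
    # registry and go uncounted; now unregistered ops are decomposed until a
    # registered op (e.g. aten.convolution) is reached.
    with torch.inference_mode():
        with FlopCounterMode(display=False) as flop_counter:
            model(x)

    print(flop_counter.get_total_flops())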

File tree

2 files changed: +70 −8 lines


test/test_flop_counter.py

+24 lines

@@ -810,6 +810,30 @@ def formula(*args, **kwargs):
         self.assertEqual(called, 1)
         self.assertExpectedInline(get_total_flops(mode), """9001""")
 
+    @skipIfNoTorchVision
+    def test_inference_mode(self):
+        def get_flops(model):
+            with FlopCounterMode(model) as mode:
+                a = T(1, 3, 224, 224)
+                model(a).sum()
+            return mode
+
+        resnet18 = torchvision_models.resnet18()
+
+        mode_standard = get_flops(resnet18)
+
+        with torch.inference_mode():
+            mode_inference = get_flops(resnet18)
+
+        self.assertEqual(get_total_flops(mode_standard), get_total_flops(mode_inference))
+
+        layer1_conv_flops_standard = mode_standard.flop_counts["ResNet.layer1"][
+            torch.ops.aten.convolution
+        ]
+        layer1_conv_flops_inference = mode_inference.flop_counts["ResNet.layer1"][
+            torch.ops.aten.convolution
+        ]
+        self.assertEqual(layer1_conv_flops_standard, layer1_conv_flops_inference)
 
 if __name__ == "__main__":
     run_tests()

torch/utils/flop_counter.py

+46 −8 lines

@@ -593,7 +593,7 @@ def nf(args):
     return nf
 
 
-class FlopCounterMode(TorchDispatchMode):
+class FlopCounterMode:
     """
     ``FlopCounterMode`` is a context manager that counts the number of flops within its context.
 
@@ -623,6 +623,7 @@ def __init__(
         self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
         self.depth = depth
         self.display = display
+        self.mode: Optional[_FlopCounterMode] = None
         if custom_mapping is None:
             custom_mapping = {}
         if mods is not None:
@@ -708,22 +709,22 @@ def process_mod(mod_name, depth):
 
         return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))
 
+    # NB: This context manager is NOT reentrant
     def __enter__(self):
         self.flop_counts.clear()
         self.mod_tracker.__enter__()
-        super().__enter__()
+        self.mode = _FlopCounterMode(self)
+        self.mode.__enter__()
         return self
 
     def __exit__(self, *args):
-        super().__exit__(*args)
+        assert self.mode is not None
+        b = self.mode.__exit__(*args)
+        self.mode = None  # break cycles
         self.mod_tracker.__exit__()
         if self.display:
             print(self.get_table(self.depth))
-
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        kwargs = kwargs if kwargs else {}
-        out = func(*args, **kwargs)
-        return self._count_flops(func._overloadpacket, out, args, kwargs)
+        return b
 
     def _count_flops(self, func_packet, out, args, kwargs):
         if func_packet in self.flop_registry:
@@ -733,3 +734,40 @@ def _count_flops(self, func_packet, out, args, kwargs):
             self.flop_counts[par][func_packet] += flop_count
 
         return out
+
+
+class _FlopCounterMode(TorchDispatchMode):
+    def __init__(self, counter: FlopCounterMode):
+        self.counter = counter
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs if kwargs else {}
+
+        # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
+        if func in {torch.ops.aten.is_contiguous.default,
+                    torch.ops.aten.is_contiguous.memory_format,
+                    torch.ops.aten.is_strides_like_format.default,
+                    torch.ops.aten.is_non_overlapping_and_dense.default,
+                    torch.ops.aten.size.default,
+                    torch.ops.aten.sym_size.default,
+                    torch.ops.aten.stride.default,
+                    torch.ops.aten.sym_stride.default,
+                    torch.ops.aten.storage_offset.default,
+                    torch.ops.aten.sym_storage_offset.default,
+                    torch.ops.aten.numel.default,
+                    torch.ops.aten.sym_numel.default,
+                    torch.ops.aten.dim.default,
+                    torch.ops.prim.layout.default}:
+
+            return NotImplemented
+
+        # If we don't have func in flop_registry, see if it can decompose
+        if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default:
+            with self:
+                r = func.decompose(*args, **kwargs)
+                if r is not NotImplemented:
+                    return r
+
+        # no further decomposition; execute & count flops
+        out = func(*args, **kwargs)
+        return self.counter._count_flops(func._overloadpacket, out, args, kwargs)
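
The key mechanism in _FlopCounterMode is OpOverload.decompose(...): roughly, if the op has a CompositeImplicitAutograd decomposition registered, decompose runs it and returns the result; otherwise it returns NotImplemented. Because the decomposition runs under `with self:`, the simpler ops it produces are dispatched back through the same mode and eventually hit a registered flop formula. A rough standalone sketch of the same pattern (LoggingDecomposeMode is a hypothetical name, not part of this patch):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LoggingDecomposeMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print("dispatched:", func)
            # Ask the op to decompose itself; the sub-ops are re-intercepted
            # because the mode is re-entered via `with self:`.
            with self:
                r = func.decompose(*args, **kwargs)
            if r is not NotImplemented:
                return r
            # No decomposition available: run the op as-is.
            return func(*args, **kwargs)

    with torch.inference_mode(), LoggingDecomposeMode():
        torch.nn.functional.linear(torch.randn(2, 3), torch.randn(4, 3))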
