From b91caadbebacfc854d0681ae2da025b5c46ad1e0 Mon Sep 17 00:00:00 2001
From: Jez Ng
Date: Thu, 13 Jun 2024 19:46:08 -0700
Subject: [PATCH] Use CUDA graphs for benchmarking

Summary:
Per https://fb.workplace.com/groups/420659799592399/posts/807860500872325/,
CUDA graph benchmarking is a lot more accurate than regular non-cudagraph
benchmarking.

I had to change a bunch of use sites of `metrics.latency` because
`do_bench_cudagraph` does not support returning quantiles. We could fix that
upstream, but it would take more time, and quantiles don't seem all that
useful in TritonBench anyway.

Reviewed By: xuzhao9, sijiac

Differential Revision: D58502780

fbshipit-source-id: 8c97b95097f49ece47ce9b1660af60afae8c25e8
---
 torchbenchmark/operators/addmm/operator.py    |  7 +++--
 .../operators/flash_attention/operator.py     |  4 +--
 torchbenchmark/operators/fp8_gemm/fp8_gemm.py |  4 +--
 .../operators/gather_gemv/operator.py         |  8 +++---
 torchbenchmark/operators/gemm/operator.py     |  5 ++--
 .../operators/int4_gemm/int4_gemm.py          |  4 +--
 .../operators/layer_norm/operator.py          | 16 +++++-------
 .../operators/low_mem_dropout/operator.py     | 12 ++++++---
 torchbenchmark/operators/softmax/operator.py  |  9 +++----
 torchbenchmark/operators/sum/operator.py      |  7 +++--
 .../operators/vector_add/operator.py          |  7 +++--
 torchbenchmark/util/triton_op.py              | 26 +++++++++----------
 12 files changed, 52 insertions(+), 57 deletions(-)

diff --git a/torchbenchmark/operators/addmm/operator.py b/torchbenchmark/operators/addmm/operator.py
index f3d98a7738..97264dd713 100644
--- a/torchbenchmark/operators/addmm/operator.py
+++ b/torchbenchmark/operators/addmm/operator.py
@@ -113,18 +113,17 @@ def gbps(
             + (torch.addmm(a, mat1, mat2).numel())
         )
         numel = numel * a.element_size() / 1e9
-        gbps = list(map(lambda x: numel / x * 1e3, metrics.latency))
-        return statistics.median(gbps)
+        return numel / metrics.latency * 1e3
 
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> List[float]:
+    ) -> float:
         _, mat1, mat2 = example_inputs
         m, k = mat1.size()
         k, n = mat2.size()
         flops = m * k * 2 * n
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency / 1e12 * 1e3
 
     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
diff --git a/torchbenchmark/operators/flash_attention/operator.py b/torchbenchmark/operators/flash_attention/operator.py
index dfcec1989a..60e04b7bc4 100644
--- a/torchbenchmark/operators/flash_attention/operator.py
+++ b/torchbenchmark/operators/flash_attention/operator.py
@@ -278,7 +278,7 @@ def sdpa_flash_attention(q, k, v):
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> List[float]:
+    ) -> float:
         flops_per_matmul = (
             2.0 * self.BATCH * self.H * self.N_CTX * self.N_CTX * self.D_HEAD
         )
@@ -289,7 +289,7 @@ def tflops(
             tflops *= 2.5  # 2.0(bwd) + 0.5(recompute)
         elif self.mode == BenchmarkMode.FWD_BWD:
             tflops *= 3.5  # 1.0(fwd) + 2.0(bwd) + 0.5(recompute)
-        return list(map(lambda x: tflops / x * 1e-9, metrics.latency))
+        return tflops / metrics.latency * 1e-9
 
     def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
         o = fwd_fn()
diff --git a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py
index 5348e3b5bb..2274a99a2f 100644
--- a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py
+++ b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py
@@ -92,7 +92,7 @@ def nbytes(t):
         m, k = a.shape
         _, n = b.shape
         gb = (nbytes(a) + nbytes(b) + nbytes(c)) / 1e9
-        return list(map(lambda x: gb / x * 1e3, metrics.latency))
+        return gb / metrics.latency * 1e3
 
     @register_metric()
     def tflops(
@@ -102,7 +102,7 @@ def tflops(
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency / 1e12 * 1e3
 
     def plot(self):
         @triton.testing.perf_report(
diff --git a/torchbenchmark/operators/gather_gemv/operator.py b/torchbenchmark/operators/gather_gemv/operator.py
index c7cb069a8f..f665e86a6e 100644
--- a/torchbenchmark/operators/gather_gemv/operator.py
+++ b/torchbenchmark/operators/gather_gemv/operator.py
@@ -24,19 +24,19 @@
 from .triton_gather_gemv import triton_gemv_0 as triton_test_0
 from torch._dynamo.testing import rand_strided
 
+
 class Operator(BenchmarkOperator):
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
         arg0_1, arg1_1, arg2_1 = example_inputs
-        gbps = (
-            lambda ms: 2
+        return (
+            2
             * arg2_1.size(0)
             * arg2_1.size(0)
             * arg0_1.element_size()
-            / ms
+            / metrics.latency
             * 1e-6
         )
-        return list(map(gbps, metrics.latency))
 
     def __init__(self, mode: str, device: str, extra_args: List[str] = []):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
diff --git a/torchbenchmark/operators/gemm/operator.py b/torchbenchmark/operators/gemm/operator.py
index e4cbe98bf9..54f05a9583 100644
--- a/torchbenchmark/operators/gemm/operator.py
+++ b/torchbenchmark/operators/gemm/operator.py
@@ -177,8 +177,7 @@ def gbps(
         a, w, bias = example_inputs
         numel = a.numel() + w.numel() + (torch.mm(a, w).numel())
         numel = numel * a.element_size() / 1e9
-        gbps = list(map(lambda x: numel / x * 1e3, metrics.latency))
-        return statistics.median(gbps)
+        return numel / metrics.latency * 1e3
 
     @register_metric(skip_baseline=True)
     def best_config(
@@ -205,7 +204,7 @@ def tflops(
             flops = m * k * 2 * n + 2 * m * n
         else:
             flops = m * k * 2 * n
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency / 1e12 * 1e3
 
     @staticmethod
     def _scaled_randn(*args, scale: float, **kwargs) -> torch.Tensor:
diff --git a/torchbenchmark/operators/int4_gemm/int4_gemm.py b/torchbenchmark/operators/int4_gemm/int4_gemm.py
index d6053b8853..9b4b2c925d 100644
--- a/torchbenchmark/operators/int4_gemm/int4_gemm.py
+++ b/torchbenchmark/operators/int4_gemm/int4_gemm.py
@@ -95,7 +95,7 @@ def nbytes(t):
         c = fn()
 
         gb = (sum(nbytes(t) for t in (x, scale_and_zero, c)) + nbytes(w) // 8) / 1e9
-        return list(map(lambda ms: gb / ms * 1e3, metrics.latency))
+        return gb / metrics.latency * 1e3
 
     @register_metric()
     def tflops(
@@ -106,7 +106,7 @@ def tflops(
         m = B * m
         _, n = b.size()
         flops = 2 * m * n * k
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency / 1e12 * 1e3
 
     def plot(self):
         @triton.testing.perf_report(
diff --git a/torchbenchmark/operators/layer_norm/operator.py b/torchbenchmark/operators/layer_norm/operator.py
index 4ea416e170..3e10ea69dc 100644
--- a/torchbenchmark/operators/layer_norm/operator.py
+++ b/torchbenchmark/operators/layer_norm/operator.py
@@ -67,16 +67,12 @@ def get_x_val(self, args):
     @register_metric()
     def gbps(self, fn_name, args, metrics: BenchmarkOperatorMetrics) -> float:
         x = args[0]
-
-        def gbps(ms):
-            base = x.numel() * x.element_size() / ms * 1e-6
-            return {
-                Mode.FWD: 2 * base,
-                Mode.BWD: 3 * base,
-                Mode.FWD_BWD: 5 * base,
-            }[self.mode]
-
-        return list(map(gbps, metrics.latency))
+        base = x.numel() * x.element_size() / metrics.latency * 1e-6
+        return {
+            Mode.FWD: 2 * base,
+            Mode.BWD: 3 * base,
+            Mode.FWD_BWD: 5 * base,
+        }[self.mode]
 
     def plot(self):
         @triton.testing.perf_report(
diff --git a/torchbenchmark/operators/low_mem_dropout/operator.py b/torchbenchmark/operators/low_mem_dropout/operator.py
index 9b978601f9..167a68ab7f 100644
--- a/torchbenchmark/operators/low_mem_dropout/operator.py
+++ b/torchbenchmark/operators/low_mem_dropout/operator.py
@@ -13,13 +13,17 @@
 from .kernels import _triton_dropout, _seeded_triton_dropout
 
+
 class Operator(BenchmarkOperator):
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
-        gbps = (
-            lambda ms: 3 * example_inputs[1].element_size() * example_inputs[1].numel() / ms * 1e-6
+        return (
+            3
+            * example_inputs[1].element_size()
+            * example_inputs[1].numel()
+            / metrics.latency
+            * 1e-6
         )
-        return list(map(gbps, metrics.latency))
 
     @register_metric()
     def tflops(
@@ -27,7 +31,7 @@ def tflops(
     ):
         p, a = example_inputs
         flops = 2 * len(a)
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency / 1e12 * 1e3
 
     @register_benchmark()
     def triton_dropout(self, p, x):
diff --git a/torchbenchmark/operators/softmax/operator.py b/torchbenchmark/operators/softmax/operator.py
index 58ba18c8dc..5c9ba7b1d9 100644
--- a/torchbenchmark/operators/softmax/operator.py
+++ b/torchbenchmark/operators/softmax/operator.py
@@ -113,15 +113,14 @@ def get_x_val(self, example_inputs) -> int:
     @register_metric()
     def gbps(
         self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
-    ) -> List[float]:
-        gbps = (
-            lambda ms: 2
+    ) -> float:
+        return (
+            2
             * example_inputs[0].nelement()
             * example_inputs[0].element_size()
             * 1e-9
-            / (ms * 1e-3)
+            / (metrics.latency * 1e-3)
         )
-        return list(map(gbps, metrics.latency))
 
     def plot(self):
         @triton.testing.perf_report(
diff --git a/torchbenchmark/operators/sum/operator.py b/torchbenchmark/operators/sum/operator.py
index 8033a0cf4c..f6f480adfe 100644
--- a/torchbenchmark/operators/sum/operator.py
+++ b/torchbenchmark/operators/sum/operator.py
@@ -187,13 +187,12 @@ def input_dims(
 
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
-        gbps = (
-            lambda ms: example_inputs[0].element_size()
+        return (
+            example_inputs[0].element_size()
             * example_inputs[0].numel()
-            / ms
+            / metrics.latency
             * 1e-6
         )
-        return list(map(gbps, metrics.latency if metrics.latency else [0]))
 
     @register_metric(skip_baseline=True)
     def best_config(
diff --git a/torchbenchmark/operators/vector_add/operator.py b/torchbenchmark/operators/vector_add/operator.py
index e270bc71ea..6b8564e65a 100644
--- a/torchbenchmark/operators/vector_add/operator.py
+++ b/torchbenchmark/operators/vector_add/operator.py
@@ -17,14 +17,13 @@
 class Operator(BenchmarkOperator):
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
-        gbps = (
-            lambda ms: 3
+        return (
+            3
             * example_inputs[0].element_size()
             * example_inputs[0].numel()
-            / ms
+            / metrics.latency
             * 1e-6
         )
-        return list(map(gbps, metrics.latency))
 
     @register_benchmark()
     def triton_add(self, x: torch.Tensor, y: torch.Tensor):
diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index 6a012233d5..b6ed0e0468 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -136,9 +136,9 @@ def dump_autotuner_best_config(kernel: triton.runtime.Autotuner) -> str:
 @dataclass
 class BenchmarkOperatorMetrics:
     # latency in ms
-    latency: Optional[List[float]]
+    latency: Optional[float]
     # tflops
-    tflops: Optional[List[float]]
+    tflops: Optional[float]
     # speedup over baseline
     speedup: Optional[float]
     # accuracy over baseline
@@ -735,13 +735,13 @@ def _init_extra_metrics() -> Dict[str, Any]:
         if set(["latency", "tflops", "speedup", "compile_time"]) & set(
             self.required_metrics
         ):
-            metrics.latency = triton.testing.do_bench(
-                fn,
-                warmup=warmup,
-                rep=rep,
-                quantiles=quantiles,
-                grad_to_none=self.get_grad_to_none(self.example_inputs),
-            )
+            with torch.cuda.stream(torch.cuda.Stream()):
+                metrics.latency = triton.testing.do_bench_cudagraph(
+                    fn,
+                    rep=rep,
+                    return_mode="median",
+                    grad_to_none=self.get_grad_to_none(self.example_inputs),
+                )
         if "walltime" in self.required_metrics:
             metrics.walltime = do_bench_walltime(
                 fn,
@@ -750,7 +750,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
             )
         if "speedup" in self.required_metrics:
             metrics.speedup = (
-                numpy.median(self.baseline_metrics.latency) / numpy.median(metrics.latency)
+                self.baseline_metrics.latency / metrics.latency
                 if self.baseline_metrics and self.baseline_metrics.latency
                 else None
             )
@@ -950,7 +950,7 @@ def compile_time(
         op_task.run()
         latency_with_compile = op_task.get_attribute("_latency_with_compile_in_task")
         del op_task
-        latency_without_compile = numpy.median(metrics.latency)
+        latency_without_compile = metrics.latency
         return latency_with_compile - latency_without_compile
 
     def hw_roofline(self) -> float:
@@ -984,7 +984,7 @@ def _compile_time_in_task(
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> List[float]:
+    ) -> float:
         def _get_flops(self, func: Callable) -> float:
             """By default, use the torch.__dispatch__ based flops counter."""
             from torch.utils.flop_counter import FlopCounterMode
@@ -1010,4 +1010,4 @@ def work_func():
         if not fn in self._op_flops:
             self._op_flops[fn] = _get_flops(self, fn)
         op_flops = self._op_flops[fn]
-        return list(map(lambda x: op_flops / x / 1e12 * 1e3, metrics.latency))
+        return op_flops / metrics.latency / 1e12 * 1e3
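
Usage note: outside TritonBench, the new code path can be exercised with a few
lines. The following is a minimal sketch, assuming a CUDA device and a Triton
build whose `do_bench_cudagraph` accepts the `rep`, `return_mode`, and
`grad_to_none` keywords passed in the diff above; `bench_with_cudagraph` is an
illustrative name, not TritonBench API.

    import torch
    from triton.testing import do_bench_cudagraph

    def bench_with_cudagraph(fn, rep: int = 100) -> float:
        # CUDA graph capture is not legal on the default stream, so run the
        # capture and replay on a fresh side stream, exactly as the patched
        # triton_op.py does.
        with torch.cuda.stream(torch.cuda.Stream()):
            # Returns a single scalar latency in ms (here the median), which
            # is why metrics.latency is now Optional[float] rather than a
            # list of quantiles.
            return do_bench_cudagraph(fn, rep=rep, return_mode="median")

    x = torch.randn(1 << 24, device="cuda")
    y = torch.randn(1 << 24, device="cuda")
    latency_ms = bench_with_cudagraph(lambda: x + y)
    print(f"median latency: {latency_ms:.4f} ms")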
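
Since `metrics.latency` is now a single scalar in milliseconds, every metric
reduces to plain scalar arithmetic. For reference, the unit conversions used
throughout the diff, restated as standalone helpers (the helper names are
illustrative, not TritonBench API):

    def gbps(bytes_moved: float, latency_ms: float) -> float:
        # GB/s = (bytes / 1e9) / (ms / 1e3) = bytes / ms * 1e-6
        return bytes_moved / latency_ms * 1e-6

    def tflops(flops: float, latency_ms: float) -> float:
        # TFLOPS = (flops / 1e12) / (ms / 1e3) = flops / ms * 1e-9
        return flops / latency_ms * 1e-9

    # Both spellings in the diff agree: `flops / latency / 1e12 * 1e3`
    # (addmm, gemm, int4_gemm, ...) and `flops / latency * 1e-9`
    # (flash_attention) compute the same number.
    flops, ms = 8.0e12, 2.0
    assert abs(tflops(flops, ms) - flops / ms / 1e12 * 1e3) < 1e-6

This is also why the low_mem_dropout tflops metric keeps the `/ 1e12 * 1e3`
factor: without it the value would be raw FLOPs per millisecond, not TFLOPS.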
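
For background on why graph replay tightens the measurement for short kernels:
the work is captured once into a CUDA graph and then replayed, so per-launch
CPU overhead largely drops out of the timing. A rough sketch of the technique
using plain torch APIs; this is an illustration of the idea under the stated
assumptions, not the actual `do_bench_cudagraph` implementation.

    import torch

    def time_graph_replay(fn, iters: int = 100) -> float:
        side = torch.cuda.Stream()
        with torch.cuda.stream(side):
            fn()  # warm up once before capture (allocations, autotuning)
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                fn()  # recorded into the graph, not executed, during capture
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(iters):
                g.replay()  # re-launch all captured work with one call
            end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # mean latency in ms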