1
1
# mypy: allow-untyped-defs
2
2
import builtins
3
3
import contextlib
4
+ import dataclasses
4
5
import functools
5
6
import inspect
6
7
import itertools
15
16
from collections import namedtuple
16
17
from concurrent .futures import as_completed , ThreadPoolExecutor
17
18
from io import StringIO
18
- from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
19
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type , TypeVar , Union
19
20
from unittest .mock import patch
20
21
21
22
import sympy
@@ -76,6 +77,61 @@ class KernelNamespace:
76
77
extern_kernels = KernelNamespace ()
77
78
78
79
80
_T = TypeVar("_T", bound="AutotuneArgs")


@dataclasses.dataclass
class BenchmarkTensors:
    """Represents a set of inputs and outputs for autotuning with a template."""

    # Tensors handed positionally to the choice being benchmarked.
    input_tensors: List[torch.Tensor]
    # Buffer the choice writes its result into; typed as optional.
    output_tensor: Optional[torch.Tensor]

    def unpack(self) -> Tuple[List[torch.Tensor], Optional[torch.Tensor]]:
        """Return ``(input_tensors, output_tensor)`` for tuple unpacking."""
        return self.input_tensors, self.output_tensor


@dataclasses.dataclass
class AutotuneArgs:
    """During autotuning, we need to pass the same inputs to all choices.

    Note:
        Since we typically have a mix of external choices and triton choices,
        we create two lists of inputs for the same underlying buffers:
        - External inputs (for aten kernels): Include offset for sliced tensors
        - Triton inputs: Use base pointer for sliced tensors, without offset
    """

    triton: BenchmarkTensors
    extern: BenchmarkTensors
    # Reference output to compare against in ``verify``; None when no
    # reference result was recorded.
    expected: Optional[torch.Tensor] = None

    def get_benchmark_tensors(self, extern: bool = False) -> BenchmarkTensors:
        """Return the input and output tensors for a given choice.

        Args:
            extern: If true, return the tensors prepared for external
                kernels; otherwise return the triton tensors.
        """
        return self.extern if extern else self.triton

    @classmethod
    def from_choice_args(
        cls: Type[_T],
        example_inputs: List[torch.Tensor],
        example_inputs_extern: List[torch.Tensor],
        out: torch.Tensor,
        out_extern: torch.Tensor,
        expected: Optional[torch.Tensor] = None,
    ) -> _T:
        """Factory method to create AutotuneArgs from separate inputs/outputs.

        Args:
            example_inputs: Inputs for triton choices.
            example_inputs_extern: Inputs for external choices.
            out: Output buffer for triton choices.
            out_extern: Output buffer for external choices.
            expected: Optional reference output used by ``verify``.
        """
        return cls(
            triton=BenchmarkTensors(example_inputs, out),
            extern=BenchmarkTensors(example_inputs_extern, out_extern),
            expected=expected,
        )

    def verify(self, **kwargs) -> None:
        """Verify the correctness of the benchmarking results.

        Compares the extern output tensor against ``expected`` using
        ``torch.testing.assert_close``; ``kwargs`` are forwarded to it.
        Does nothing when no reference output was recorded.

        Raises:
            AssertionError: If the extern output does not match ``expected``.
        """
        # Without a reference there is nothing to compare against; calling
        # assert_close with None would report a spurious mismatch.
        if self.expected is None:
            return
        torch.testing.assert_close(self.extern.output_tensor, self.expected, **kwargs)
133
+
134
+
79
135
class PartialRender :
80
136
"""
81
137
Some parts of a template need to be generated at the end, but
@@ -1456,7 +1512,9 @@ def make_benchmark_fn(
1456
1512
if input_gen_fns is None :
1457
1513
input_gen_fns = {}
1458
1514
1459
- def get_inputs (choices : List [ChoiceCaller ]):
1515
+ def get_inputs (
1516
+ choices : Union [List [ExternKernelCaller ], List [TritonTemplateCaller ]]
1517
+ ) -> AutotuneArgs :
1460
1518
# de-duplicate args
1461
1519
unique_example_inputs = {
1462
1520
x .get_name (): input_gen_fns .get (i , cls .benchmark_example_value )(x )
@@ -1489,55 +1547,44 @@ def get_inputs(choices: List[ChoiceCaller]):
1489
1547
out_extern = torch .as_strided (
1490
1548
out , out .size (), out .stride (), V .graph .sizevars .size_hint (layout .offset )
1491
1549
)
1492
-
1493
1550
expected = None
1494
1551
if VERIFY :
1495
1552
choices [0 ].benchmark (* example_inputs_extern , out = out_extern )
1496
1553
expected = out_extern .clone ()
1497
1554
1498
- return example_inputs , example_inputs_extern , out , out_extern , expected
1555
+ return AutotuneArgs .from_choice_args (
1556
+ example_inputs ,
1557
+ example_inputs_extern ,
1558
+ out ,
1559
+ out_extern ,
1560
+ expected ,
1561
+ )
1499
1562
1500
1563
if DEBUG :
1501
1564
print (f"{ len (choices )} tuning requests:" )
1502
1565
1503
- def debug_str (example_inputs , out ):
1504
- def tensor_repr (x ):
1505
- return (
1506
- f"torch.empty_strided({ tuple (x .size ())!r} , { tuple (x .stride ())!r} , "
1507
- f"dtype={ x .dtype !r} , device={ x .device .type !r} )"
1508
- )
1509
-
1510
- lines = [
1511
- "inputs = [" ,
1512
- ]
1513
- for x in example_inputs :
1514
- lines .append (f" { tensor_repr (x )} ," )
1515
- lines += ["]" , f"out = { tensor_repr (out )} " , "" ]
1516
- return "\n " .join (lines )
1517
-
1518
1566
def benchmark_choice_in_current_process(
    choice: ChoiceCaller, autotune_args: AutotuneArgs
) -> float:
    """Benchmark a single choice in this process and return its result.

    ExternKernelCaller choices are run on the extern tensors of
    ``autotune_args``; every other choice is run on the triton tensors.
    The selected output buffer is zeroed before the run. When VERIFY is
    set and a reference output is available, the result is checked
    afterwards.
    """
    use_extern = isinstance(choice, ExternKernelCaller)
    args, out_buf = autotune_args.get_benchmark_tensors(use_extern).unpack()
    # Clear the buffer before running this choice.
    out_buf.zero_()
    timing = choice.benchmark(*args, out=out_buf)
    should_verify = VERIFY and autotune_args.expected is not None
    if should_verify:
        autotune_args.verify(**VERIFY)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # shake out any CUDA errors
    return timing
1533
1579
1534
- def benchmark_in_current_process (choices ):
1580
+ def benchmark_in_current_process (
1581
+ choices : Union [List [ExternKernelCaller ], List [TritonTemplateCaller ]],
1582
+ ) -> Dict [Union [ExternKernelCaller , TritonTemplateCaller ], float ]:
1535
1583
inputs = get_inputs (choices )
1536
- example_inputs , _ , out , _ , _ = inputs
1537
1584
timings = {}
1538
1585
for choice in choices :
1539
1586
try :
1540
- timing = benchmark_choice_in_current_process (choice , * inputs )
1587
+ timing = benchmark_choice_in_current_process (choice , inputs )
1541
1588
except CUDACompileError as e :
1542
1589
log .error (
1543
1590
"CUDA compilation error during autotuning: \n %s. \n Ignoring this choice." ,
@@ -1579,7 +1626,9 @@ def benchmark_in_current_process(choices):
1579
1626
1580
1627
return timings
1581
1628
1582
- def benchmark_in_sub_process (choices ):
1629
+ def benchmark_in_sub_process (
1630
+ choices : Union [List [ExternKernelCaller ], List [TritonTemplateCaller ]]
1631
+ ):
1583
1632
from . import autotune_process
1584
1633
1585
1634
# only benchmark triton kernel in sub process for now.
@@ -1588,7 +1637,7 @@ def benchmark_in_sub_process(choices):
1588
1637
triton = [c for c in choices if not isinstance (c , ExternKernelCaller )]
1589
1638
1590
1639
timings = benchmark_in_current_process (extern )
1591
- timings .update (autotune_process .benchmark_in_sub_process (triton ))
1640
+ timings .update (autotune_process .benchmark_in_sub_process (triton )) # type: ignore[arg-type]
1592
1641
return timings
1593
1642
1594
1643
benchmark = (
0 commit comments