Skip to content

Commit 53c9c19

Browse files
drisspgpytorchmergebot
authored andcommitted
[Autotune Inductor] Some clean up and dataclassing (pytorch#139157)
Pull Request resolved: pytorch#139157 Approved by: https://github.com/eellison
1 parent c2d7544 commit 53c9c19

File tree

1 file changed

+84
-35
lines changed

1 file changed

+84
-35
lines changed

torch/_inductor/select_algorithm.py

+84-35
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# mypy: allow-untyped-defs
22
import builtins
33
import contextlib
4+
import dataclasses
45
import functools
56
import inspect
67
import itertools
@@ -15,7 +16,7 @@
1516
from collections import namedtuple
1617
from concurrent.futures import as_completed, ThreadPoolExecutor
1718
from io import StringIO
18-
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
19+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
1920
from unittest.mock import patch
2021

2122
import sympy
@@ -76,6 +77,61 @@ class KernelNamespace:
7677
extern_kernels = KernelNamespace()
7778

7879

80+
_T = TypeVar("_T", bound="AutotuneArgs")
81+
82+
83+
@dataclasses.dataclass
84+
class BenchmarkTensors:
85+
"""Represents a set of inputs and outputs for autotuning with a template"""
86+
87+
input_tensors: List[torch.Tensor]
88+
output_tensor: Optional[torch.Tensor]
89+
90+
def unpack(self):
91+
return self.input_tensors, self.output_tensor
92+
93+
94+
@dataclasses.dataclass
95+
class AutotuneArgs:
96+
"""During autotuning, we need to pass the same inputs to all choices.
97+
Note:
98+
Since we typically have a mix of external choices and triton choices, we create
99+
two lists of inputs for the same underlying buffers:
100+
- External inputs (for aten kernels): Include offset for sliced tensors
101+
- Triton inputs: Use base pointer for sliced tensors, without offset
102+
"""
103+
104+
triton: BenchmarkTensors
105+
extern: BenchmarkTensors
106+
expected: Optional[torch.Tensor] = None
107+
108+
def get_benchmark_tensors(self, extern=False) -> BenchmarkTensors:
109+
"""Returns the inputs and output tensors for a given choice."""
110+
bench_tensors = self.extern if extern else self.triton
111+
return bench_tensors
112+
113+
@classmethod
114+
def from_choice_args(
115+
cls: Type[_T],
116+
example_inputs: List[torch.Tensor],
117+
example_inputs_extern: List[torch.Tensor],
118+
out: torch.Tensor,
119+
out_extern: torch.Tensor,
120+
expected: Optional[torch.Tensor] = None,
121+
) -> _T:
122+
"""Factory method to create AutotuneInputs from separate inputs/outputs"""
123+
return cls(
124+
triton=BenchmarkTensors(example_inputs, out),
125+
extern=BenchmarkTensors(example_inputs_extern, out_extern),
126+
expected=expected,
127+
)
128+
129+
def verify(self, **kwargs):
130+
"""Verify the correctness of the benchmarking results"""
131+
132+
torch.testing.assert_close(self.extern.output_tensor, self.expected, **kwargs)
133+
134+
79135
class PartialRender:
80136
"""
81137
Some parts of a template need to be generated at the end, but
@@ -1456,7 +1512,9 @@ def make_benchmark_fn(
14561512
if input_gen_fns is None:
14571513
input_gen_fns = {}
14581514

1459-
def get_inputs(choices: List[ChoiceCaller]):
1515+
def get_inputs(
1516+
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]]
1517+
) -> AutotuneArgs:
14601518
# de-duplicate args
14611519
unique_example_inputs = {
14621520
x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
@@ -1489,55 +1547,44 @@ def get_inputs(choices: List[ChoiceCaller]):
14891547
out_extern = torch.as_strided(
14901548
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
14911549
)
1492-
14931550
expected = None
14941551
if VERIFY:
14951552
choices[0].benchmark(*example_inputs_extern, out=out_extern)
14961553
expected = out_extern.clone()
14971554

1498-
return example_inputs, example_inputs_extern, out, out_extern, expected
1555+
return AutotuneArgs.from_choice_args(
1556+
example_inputs,
1557+
example_inputs_extern,
1558+
out,
1559+
out_extern,
1560+
expected,
1561+
)
14991562

15001563
if DEBUG:
15011564
print(f"{len(choices)} tuning requests:")
15021565

1503-
def debug_str(example_inputs, out):
1504-
def tensor_repr(x):
1505-
return (
1506-
f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, "
1507-
f"dtype={x.dtype!r}, device={x.device.type!r})"
1508-
)
1509-
1510-
lines = [
1511-
"inputs = [",
1512-
]
1513-
for x in example_inputs:
1514-
lines.append(f" {tensor_repr(x)},")
1515-
lines += ["]", f"out = {tensor_repr(out)}", ""]
1516-
return "\n".join(lines)
1517-
15181566
def benchmark_choice_in_current_process(
1519-
choice, example_inputs, example_inputs_extern, out, out_extern, expected
1520-
):
1521-
out.zero_()
1522-
if isinstance(choice, ExternKernelCaller):
1523-
# aten kernels want the offset baked in for sliced tensors
1524-
result = choice.benchmark(*example_inputs_extern, out=out_extern)
1525-
else:
1526-
# triton templates want the base pointer for sliced tensors
1527-
result = choice.benchmark(*example_inputs, out=out)
1528-
if VERIFY and expected is not None:
1529-
torch.testing.assert_close(out_extern, expected, **VERIFY)
1567+
choice: ChoiceCaller, autotune_args: AutotuneArgs
1568+
) -> float:
1569+
is_extern = isinstance(choice, ExternKernelCaller)
1570+
benchmark_tensors = autotune_args.get_benchmark_tensors(is_extern)
1571+
inpts, output = benchmark_tensors.unpack()
1572+
output.zero_()
1573+
result = choice.benchmark(*inpts, out=output)
1574+
if VERIFY and autotune_args.expected is not None:
1575+
autotune_args.verify(**VERIFY)
15301576
if torch.cuda.is_available():
15311577
torch.cuda.synchronize() # shake out any CUDA errors
15321578
return result
15331579

1534-
def benchmark_in_current_process(choices):
1580+
def benchmark_in_current_process(
1581+
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]],
1582+
) -> Dict[Union[ExternKernelCaller, TritonTemplateCaller], float]:
15351583
inputs = get_inputs(choices)
1536-
example_inputs, _, out, _, _ = inputs
15371584
timings = {}
15381585
for choice in choices:
15391586
try:
1540-
timing = benchmark_choice_in_current_process(choice, *inputs)
1587+
timing = benchmark_choice_in_current_process(choice, inputs)
15411588
except CUDACompileError as e:
15421589
log.error(
15431590
"CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
@@ -1579,7 +1626,9 @@ def benchmark_in_current_process(choices):
15791626

15801627
return timings
15811628

1582-
def benchmark_in_sub_process(choices):
1629+
def benchmark_in_sub_process(
1630+
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]]
1631+
):
15831632
from . import autotune_process
15841633

15851634
# only benchmark triton kernel in sub process for now.
@@ -1588,7 +1637,7 @@ def benchmark_in_sub_process(choices):
15881637
triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]
15891638

15901639
timings = benchmark_in_current_process(extern)
1591-
timings.update(autotune_process.benchmark_in_sub_process(triton))
1640+
timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type]
15921641
return timings
15931642

15941643
benchmark = (

0 commit comments

Comments
 (0)