From f3dfd05c6a25e2afbc2dd5d48cc781af17fe284e Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 17 Feb 2025 17:38:59 -0500 Subject: [PATCH] Refactor/cubecl fusion (#2815) --- .github/workflows/publish.yml | 14 + Cargo.lock | 57 +- Cargo.toml | 4 +- crates/burn-core/Cargo.toml | 104 +- crates/burn-core/src/lib.rs | 6 - crates/burn-cubecl-fusion/Cargo.toml | 35 + crates/burn-cubecl-fusion/README.md | 3 + crates/burn-cubecl-fusion/src/base.rs | 165 ++ .../src}/elemwise/builder.rs | 14 +- crates/burn-cubecl-fusion/src/elemwise/mod.rs | 2 + .../src/elemwise/optimization.rs | 127 ++ .../mod.rs => burn-cubecl-fusion/src/lib.rs} | 8 +- .../src}/matmul/args.rs | 81 +- .../src}/matmul/builder.rs | 26 +- crates/burn-cubecl-fusion/src/matmul/mod.rs | 8 + .../src}/matmul/optimization.rs | 132 +- .../src}/matmul/spec.rs | 0 .../src}/matmul/tune.rs | 51 +- .../burn-cubecl-fusion/src/on_write/base.rs | 1 + .../src}/on_write/builder.rs | 0 crates/burn-cubecl-fusion/src/on_write/io.rs | 432 ++++++ .../src}/on_write/ir.rs | 166 ++- .../burn-cubecl-fusion/src/on_write/kernel.rs | 609 ++++++++ .../src}/on_write/mod.rs | 4 + .../src}/on_write/settings.rs | 0 .../burn-cubecl-fusion/src/on_write/tensor.rs | 288 ++++ .../src}/on_write/trace/base.rs | 48 +- .../src}/on_write/trace/builder.rs | 22 +- .../src/on_write/trace/executor.rs | 251 ++++ .../src}/on_write/trace/input.rs | 31 +- .../src}/on_write/trace/mod.rs | 0 .../src}/on_write/trace/output.rs | 102 +- .../src}/on_write/trace/plan.rs | 16 +- .../src}/on_write/trace/runner.rs | 6 +- .../src}/on_write/trace/vectorization.rs | 12 +- .../fusion => burn-cubecl-fusion/src}/tune.rs | 20 +- crates/burn-cubecl/Cargo.toml | 3 +- crates/burn-cubecl/README.md | 2 +- .../src/{fusion/base.rs => fusion.rs} | 198 +-- crates/burn-cubecl/src/fusion/elemwise/mod.rs | 2 - .../src/fusion/elemwise/optimization.rs | 178 --- crates/burn-cubecl/src/fusion/matmul/mod.rs | 5 - crates/burn-cubecl/src/fusion/on_write/io.rs | 1155 -------------- .../burn-cubecl/src/fusion/on_write/kernel.rs | 1325 ----------------- .../src/fusion/on_write/trace/executor.rs | 230 --- .../conv/conv2d/gemm/homogeneous/base.rs | 6 +- .../kernel/conv/conv2d/gemm/loader/im2col.rs | 25 +- crates/burn-cubecl/src/lib.rs | 27 +- crates/burn/Cargo.toml | 118 +- crates/{burn-core => burn}/src/backend.rs | 0 crates/burn/src/lib.rs | 6 + 51 files changed, 2549 insertions(+), 3576 deletions(-) create mode 100644 crates/burn-cubecl-fusion/Cargo.toml create mode 100644 crates/burn-cubecl-fusion/README.md create mode 100644 crates/burn-cubecl-fusion/src/base.rs rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/elemwise/builder.rs (82%) create mode 100644 crates/burn-cubecl-fusion/src/elemwise/mod.rs create mode 100644 crates/burn-cubecl-fusion/src/elemwise/optimization.rs rename crates/{burn-cubecl/src/fusion/mod.rs => burn-cubecl-fusion/src/lib.rs} (50%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/matmul/args.rs (72%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/matmul/builder.rs (84%) create mode 100644 crates/burn-cubecl-fusion/src/matmul/mod.rs rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/matmul/optimization.rs (78%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/matmul/spec.rs (100%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/matmul/tune.rs (69%) create mode 100644 crates/burn-cubecl-fusion/src/on_write/base.rs rename crates/{burn-cubecl/src/fusion => 
burn-cubecl-fusion/src}/on_write/builder.rs (100%) create mode 100644 crates/burn-cubecl-fusion/src/on_write/io.rs rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/ir.rs (66%) create mode 100644 crates/burn-cubecl-fusion/src/on_write/kernel.rs rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/mod.rs (68%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/settings.rs (100%) create mode 100644 crates/burn-cubecl-fusion/src/on_write/tensor.rs rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/trace/base.rs (81%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/trace/builder.rs (97%) create mode 100644 crates/burn-cubecl-fusion/src/on_write/trace/executor.rs rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/trace/input.rs (77%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/trace/mod.rs (100%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/trace/output.rs (81%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/trace/plan.rs (86%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/trace/runner.rs (98%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/on_write/trace/vectorization.rs (94%) rename crates/{burn-cubecl/src/fusion => burn-cubecl-fusion/src}/tune.rs (87%) rename crates/burn-cubecl/src/{fusion/base.rs => fusion.rs} (54%) delete mode 100644 crates/burn-cubecl/src/fusion/elemwise/mod.rs delete mode 100644 crates/burn-cubecl/src/fusion/elemwise/optimization.rs delete mode 100644 crates/burn-cubecl/src/fusion/matmul/mod.rs delete mode 100644 crates/burn-cubecl/src/fusion/on_write/io.rs delete mode 100644 crates/burn-cubecl/src/fusion/on_write/kernel.rs delete mode 100644 crates/burn-cubecl/src/fusion/on_write/trace/executor.rs rename crates/{burn-core => burn}/src/backend.rs (100%) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index c37fd9684d..d4b0b9700b 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -14,6 +14,7 @@ jobs: - publish-burn-autodiff - publish-burn-candle - publish-burn-fusion + - publish-burn-cubecl-fusion - publish-burn-cubecl - publish-burn-ndarray - publish-burn-tch @@ -113,12 +114,25 @@ jobs: secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + publish-burn-cubecl-fusion: + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 + needs: + - publish-burn-ir + - publish-burn-common + - publish-burn-fusion + - publish-burn-tensor + with: + crate: burn-cubecl-fusion + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + publish-burn-cubecl: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 needs: - publish-burn-ir - publish-burn-common - publish-burn-fusion + - publish-burn-cubecl-fusion - publish-burn-tensor - publish-burn-ndarray with: diff --git a/Cargo.lock b/Cargo.lock index 51cb39fcb8..b1f28aa0fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -555,8 +555,17 @@ checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" name = "burn" version = "0.17.0" dependencies = [ + "burn-autodiff", + "burn-candle", "burn-core", + "burn-cuda", + "burn-hip", + "burn-ndarray", + "burn-remote", + "burn-router", + "burn-tch", "burn-train", + "burn-wgpu", ] [[package]] @@ -605,7 +614,6 @@ dependencies = [ "ahash", "bincode", "burn-autodiff", - "burn-candle", "burn-common", "burn-cuda", 
"burn-dataset", @@ -642,6 +650,7 @@ version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", + "burn-cubecl-fusion", "burn-fusion", "burn-ir", "burn-ndarray", @@ -663,6 +672,20 @@ dependencies = [ "text_placeholder", ] +[[package]] +name = "burn-cubecl-fusion" +version = "0.17.0" +dependencies = [ + "burn-common", + "burn-fusion", + "burn-ir", + "burn-tensor", + "cubecl", + "derive-new 0.7.0", + "half", + "serde", +] + [[package]] name = "burn-cuda" version = "0.17.0" @@ -1508,7 +1531,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1523,7 +1546,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1544,7 +1567,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1565,7 +1588,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "bytemuck", "cubecl-common", @@ -1579,7 +1602,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "bytemuck", "cubecl-common", @@ -1595,7 +1618,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "bytemuck", "cubecl-common", @@ -1621,7 +1644,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -1639,7 +1662,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "bytemuck", "cubecl-core", @@ -1652,7 +1675,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "cubecl-common", "darling", @@ -1667,7 +1690,7 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "darling", "proc-macro2", @@ -1678,7 +1701,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1694,7 +1717,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1704,7 +1727,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "async-channel", "async-lock", @@ -1726,7 +1749,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1741,7 +1764,7 @@ dependencies = [ [[package]] name = "cubecl-std" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = "git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1750,7 +1773,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7f07d3969bc7d69c6ae2f87bd806dd4f18267741#7f07d3969bc7d69c6ae2f87bd806dd4f18267741" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=8b025f26e5badbf1b8f3e6787fc097427cd961ec#8b025f26e5badbf1b8f3e6787fc097427cd961ec" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 23bb8df53a..c917977f1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -155,8 +155,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7f07d3969bc7d69c6ae2f87bd806dd4f18267741" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7f07d3969bc7d69c6ae2f87bd806dd4f18267741" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8b025f26e5badbf1b8f3e6787fc097427cd961ec" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8b025f26e5badbf1b8f3e6787fc097427cd961ec" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index 423dc784d8..cf4b2898b8 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -15,61 +15,27 @@ version.workspace = true dataset = ["burn-dataset"] default = [ "std", - "burn-candle?/default", "burn-common/default", "burn-dataset?/default", - "burn-ndarray?/default", - "burn-tch?/default", "burn-tensor/default", - "burn-wgpu?/default", - "burn-router?/default", - "burn-cuda?/default", - "burn-autodiff?/default", - "burn-hip?/default", ] doc = [ "std", - # Backends "dataset", - "candle", - "fusion", - "ndarray", - "tch", - "wgpu", - "cuda", - "hip", "audio", "vision", - "autodiff", - "remote", - "router", - "server", # Doc features - "burn-candle/doc", "burn-common/doc", "burn-dataset/doc", - "burn-ndarray/doc", - "burn-tch/doc", "burn-tensor/doc", - "burn-wgpu/doc", - "burn-router/doc", - "burn-cuda/doc", - "burn-hip/doc", ] network = ["burn-common/network"] sqlite = ["burn-dataset?/sqlite"] sqlite-bundled = ["burn-dataset?/sqlite-bundled"] std = [ - "burn-autodiff?/std", "bincode/std", - "burn-candle?/std", "burn-common/std", - "burn-ndarray?/std", "burn-tensor/std", - "burn-wgpu?/std", - "burn-router?/std", - "burn-cuda?/std", - "burn-hip?/std", "flate2", "half/std", "log", @@ -82,45 +48,27 @@ std = [ vision = ["burn-dataset?/vision", "burn-common/network"] audio = ["burn-dataset?/audio"] -# Backend -autodiff = ["burn-autodiff"] -fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"] - -## Backend features -accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"] -autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"] -blas-netlib = ["burn-ndarray?/blas-netlib"] -metal = ["burn-candle?/metal"] -openblas = ["burn-ndarray?/blas-openblas"] -openblas-system = ["burn-ndarray?/blas-openblas-system"] -remote = ["burn-remote/client"] -router = ["burn-router"] -server = ["burn-remote/server"] -template = ["burn-wgpu?/template"] - -candle = ["burn-candle"] -candle-cuda = ["candle", "burn-candle/cuda"] -cuda = ["burn-cuda"] -hip = ["burn-hip"] -ndarray = ["burn-ndarray"] -tch = ["burn-tch"] -wgpu = ["burn-wgpu"] -vulkan = ["wgpu", "burn-wgpu/vulkan"] -webgpu = ["wgpu", "burn-wgpu/webgpu"] - # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. 
record-item-custom-serde = ["thiserror", "regex"] # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] -test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray. -test-hip = ["hip"] # To use hip during testing, default uses ndarray. -test-tch = ["tch"] # To use tch during testing, default uses ndarray. -test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. +test-cuda = [ + "burn-cuda/default", +] # To use cuda during testing, default uses ndarray. +test-hip = [ + "burn-hip/default", +] # To use hip during testing, default uses ndarray. +test-tch = [ + "burn-tch/default", +] # To use tch during testing, default uses ndarray. +test-wgpu = [ + "burn-wgpu/default", +] # To use wgpu during testing, default uses ndarray. test-wgpu-spirv = [ - "test-wgpu", - "vulkan", + "burn-wgpu/default", + "burn-wgpu/vulkan", ] # To use wgpu-spirv during testing, default uses ndarray. [dependencies] @@ -132,17 +80,6 @@ burn-dataset = { path = "../burn-dataset", version = "0.17.0", optional = true, burn-derive = { path = "../burn-derive", version = "0.17.0" } burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } -# Backends -burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } -burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true } -burn-cuda = { path = "../burn-cuda", version = "0.17.0", optional = true, default-features = false } -burn-hip = { path = "../burn-hip", version = "0.17.0", optional = true, default-features = false } -burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true, default-features = false } -burn-remote = { path = "../burn-remote", version = "0.17.0", default-features = false, optional = true } -burn-router = { path = "../burn-router", version = "0.17.0", default-features = false, optional = true } -burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true } -burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", optional = true, default-features = false } - data-encoding = { workspace = true } uuid = { workspace = true } @@ -167,18 +104,25 @@ serde_json = { workspace = true, features = ["alloc"] } #Default enables std spin = { workspace = true } # Using in place of use std::sync::Mutex when std is disabled thiserror = { workspace = true, optional = true } +# FOR TESTING +burn-cuda = { path = "../burn-cuda", version = "0.17.0", optional = true, default-features = false } +burn-hip = { path = "../burn-hip", version = "0.17.0", optional = true, default-features = false } +burn-remote = { path = "../burn-remote", version = "0.17.0", default-features = false, optional = true } +burn-router = { path = "../burn-router", version = "0.17.0", default-features = false, optional = true } +burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", optional = true, default-features = false } + [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] portable-atomic-util = { workspace = true } [dev-dependencies] +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0" } burn-dataset = { path = "../burn-dataset", version = "0.17.0", features = [ "fake", ] } tempfile = { workspace = true } -burn-autodiff = { path = "../burn-autodiff", version = "0.17.0" } -burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } - 
[package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index 34887ec6b9..a061c02604 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -41,12 +41,6 @@ pub mod record; /// Module for the tensor. pub mod tensor; -/// Backend module. -pub mod backend; - -#[cfg(feature = "server")] -pub use burn_remote::server; - extern crate alloc; /// Backend for test cases diff --git a/crates/burn-cubecl-fusion/Cargo.toml b/crates/burn-cubecl-fusion/Cargo.toml new file mode 100644 index 0000000000..923079df60 --- /dev/null +++ b/crates/burn-cubecl-fusion/Cargo.toml @@ -0,0 +1,35 @@ +[package] +authors = ["nathanielsimard "] +categories = ["science"] +description = "Provide optimizations that can be used with cubecl based backends." +documentation = "https://docs.rs/burn-cubecl-fusion" +edition.workspace = true +keywords = ["deep-learning", "machine-learning", "gpu"] +license.workspace = true +name = "burn-cubecl-fusion" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cubecl-fusion" +version.workspace = true + +[features] +autotune = [] +default = ["autotune", "std", "cubecl/default"] +doc = ["default"] +std = ["cubecl/std", "burn-tensor/std"] + +[dependencies] +burn-common = { path = "../burn-common", version = "0.17.0" } +burn-fusion = { path = "../burn-fusion", version = "0.17.0" } +burn-ir = { path = "../burn-ir", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ + "cubecl", +] } +cubecl = { workspace = true, features = ["linalg"] } + +half = { workspace = true } +serde = { workspace = true } +derive-new = { workspace = true } + +[package.metadata.docs.rs] +features = ["doc"] +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/burn-cubecl-fusion/README.md b/crates/burn-cubecl-fusion/README.md new file mode 100644 index 0000000000..88c579785f --- /dev/null +++ b/crates/burn-cubecl-fusion/README.md @@ -0,0 +1,3 @@ +# Burn CubeCl Fusion + +Provide optimizations that can be used with [cubecl](../burn-cubecl) based backends. diff --git a/crates/burn-cubecl-fusion/src/base.rs b/crates/burn-cubecl-fusion/src/base.rs new file mode 100644 index 0000000000..a3cf82f586 --- /dev/null +++ b/crates/burn-cubecl-fusion/src/base.rs @@ -0,0 +1,165 @@ +use std::marker::PhantomData; + +use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState}; +use super::matmul::optimization::{MatmulOptimization, MatmulOptimizationState}; + +use burn_tensor::DType; +use cubecl::client::ComputeClient; +use cubecl::ir::Elem; +use cubecl::prelude::{TensorArg, TensorHandleRef}; +use cubecl::{CubeElement, Runtime}; +use serde::{Deserialize, Serialize}; + +/// Fusion optimization type for cubecl. +/// +/// More optimization variants should be added here. +pub enum CubeOptimization { + /// Element wise optimization. + ElementWise(ElemwiseOptimization), + /// Matrix multiplication optimization. + Matmul(MatmulOptimization), +} + +/// Fusion optimization state type for cubecl. +/// +/// More optimization variants should be added here. +#[derive(Serialize, Deserialize)] +pub enum CubeOptimizationState { + /// Element wise state. + ElementWise(ElemwiseOptimizationState), + /// Matrix multiplication optimization state. 
+ Matmul(MatmulOptimizationState), +} + +pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec { + let mut strides = vec![0; shape.len()]; + + let mut current = 1; + shape.iter().enumerate().rev().for_each(|(index, val)| { + strides[index] = current; + current *= val; + }); + + strides +} + +pub(crate) fn elem_dtype() -> DType { + match E::cube_elem() { + Elem::Float(kind) => match kind { + cubecl::ir::FloatKind::F16 => DType::F16, + cubecl::ir::FloatKind::BF16 => DType::BF16, + cubecl::ir::FloatKind::F32 => DType::F32, + _ => todo!(), + }, + Elem::Int(kind) => match kind { + cubecl::ir::IntKind::I64 => DType::I64, + cubecl::ir::IntKind::I32 => DType::I32, + cubecl::ir::IntKind::I16 => DType::I16, + cubecl::ir::IntKind::I8 => DType::I8, + }, + Elem::UInt(kind) => match kind { + cubecl::ir::UIntKind::U64 => DType::U64, + cubecl::ir::UIntKind::U32 => DType::U32, + cubecl::ir::UIntKind::U16 => DType::U16, + cubecl::ir::UIntKind::U8 => DType::U8, + }, + Elem::Bool => DType::Bool, + _ => todo!(), + } +} + +/// Handle to be used when fusing operations. +pub struct CubeFusionHandle { + /// Compute client for jit. + pub client: ComputeClient, + /// The buffer where the data are stored. + pub handle: cubecl::server::Handle, + /// The device of the current tensor. + pub device: R::Device, + /// The element type of the tensor. + pub dtype: DType, + /// The strides of the tensor. + pub strides: Vec, +} + +impl core::fmt::Debug for CubeFusionHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!( + "CubeFusionHandle {{ device: {:?}, runtime: {}}}", + self.device, + R::name(), + )) + } +} + +impl Clone for CubeFusionHandle { + fn clone(&self) -> Self { + Self { + client: self.client.clone(), + handle: self.handle.clone(), + device: self.device.clone(), + strides: self.strides.clone(), + dtype: self.dtype, + } + } +} + +unsafe impl Send for CubeFusionHandle {} +unsafe impl Sync for CubeFusionHandle {} + +impl CubeFusionHandle { + /// Return the reference to a tensor handle. + pub fn as_handle_ref<'a>(&'a self, shape: &'a [usize]) -> TensorHandleRef<'a, R> { + TensorHandleRef { + handle: &self.handle, + strides: &self.strides, + shape, + runtime: PhantomData, + elem_size: self.dtype.size(), + } + } + /// Return the reference to a tensor argument. 
+ pub fn as_tensor_arg<'a>(&'a self, shape: &'a [usize], vectorisation: u8) -> TensorArg<'a, R> { + let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape); + + unsafe { + TensorArg::from_raw_parts_and_size( + handle.handle, + handle.strides, + handle.shape, + vectorisation, + self.dtype.size(), + ) + } + } +} + +pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { + if shape.is_empty() { + return true; + } + + if shape.len() == 1 { + return strides[0] == 1; + } + + let mut prev_stride = 1; + let mut current_num_elems_shape = 1; + + for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() { + if i > 0 { + if current_num_elems_shape != *stride { + return false; + } + + if prev_stride >= *stride { + return false; + } + } + + current_num_elems_shape *= shape; + prev_stride = *stride; + } + + true +} diff --git a/crates/burn-cubecl/src/fusion/elemwise/builder.rs b/crates/burn-cubecl-fusion/src/elemwise/builder.rs similarity index 82% rename from crates/burn-cubecl/src/fusion/elemwise/builder.rs rename to crates/burn-cubecl-fusion/src/elemwise/builder.rs index 9fbff34bc4..ca68a01b10 100644 --- a/crates/burn-cubecl/src/fusion/elemwise/builder.rs +++ b/crates/burn-cubecl-fusion/src/elemwise/builder.rs @@ -1,22 +1,20 @@ use burn_fusion::OptimizationBuilder; +use cubecl::Runtime; use crate::{ - fusion::{ - on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, - CubeOptimization, - }, - CubeRuntime, + on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, + CubeOptimization, }; use super::optimization::ElemwiseOptimization; /// Fused element wise operations that are normally memory bound. -pub(crate) struct ElementWiseBuilder { +pub struct ElementWiseBuilder { builder: FuseOnWriteBuilder, device: R::Device, } -impl ElementWiseBuilder { +impl ElementWiseBuilder { pub fn new(device: R::Device, bool_precision: ElemwisePrecision) -> Self { let client = R::client(&device); let props = client.properties(); @@ -38,7 +36,7 @@ impl ElementWiseBuilder { } } -impl OptimizationBuilder> for ElementWiseBuilder { +impl OptimizationBuilder> for ElementWiseBuilder { fn register(&mut self, operation: &burn_ir::OperationIr) { self.builder.register(operation); } diff --git a/crates/burn-cubecl-fusion/src/elemwise/mod.rs b/crates/burn-cubecl-fusion/src/elemwise/mod.rs new file mode 100644 index 0000000000..a35c966dd5 --- /dev/null +++ b/crates/burn-cubecl-fusion/src/elemwise/mod.rs @@ -0,0 +1,2 @@ +pub mod builder; +pub mod optimization; diff --git a/crates/burn-cubecl-fusion/src/elemwise/optimization.rs b/crates/burn-cubecl-fusion/src/elemwise/optimization.rs new file mode 100644 index 0000000000..05f38bd1ec --- /dev/null +++ b/crates/burn-cubecl-fusion/src/elemwise/optimization.rs @@ -0,0 +1,127 @@ +use crate::on_write::ir::GlobalArgs; +use crate::on_write::{io::global_length, kernel::fuse_on_write}; +use crate::CubeFusionHandle; +use burn_fusion::stream::Context; +use cubecl::{calculate_cube_count_elemwise, client::ComputeClient, prelude::*, CubeDim}; +use serde::{Deserialize, Serialize}; + +use crate::on_write::{ + ir::{Arg, ElemwiseConfig, GlobalArgsLaunch}, + trace::{FuseOnWriteTrace, TraceRunner}, +}; + +#[derive(new)] +/// Fuse element wise operations into a single kernel. +pub struct ElemwiseOptimization { + trace: FuseOnWriteTrace, + client: ComputeClient, + device: R::Device, + len: usize, +} + +#[derive(Serialize, Deserialize)] +/// State for the [elemwise optimization](ElemwiseOptimization). 
+pub struct ElemwiseOptimizationState { + trace: FuseOnWriteTrace, + len: usize, +} + +impl ElemwiseOptimization { + /// Execute the optimization. + pub fn execute(&mut self, context: &mut Context<'_, CubeFusionHandle>) { + self.trace + .run::(&self.client, &self.device, context, &ElemwiseRunner) + .unwrap(); + } + + /// Number of element wise operations fused. + pub fn num_ops_fused(&self) -> usize { + self.len + } + + /// Create an optimization from its [state](ElemwiseOptimizationState). + pub fn from_state(device: &R::Device, state: ElemwiseOptimizationState) -> Self { + Self { + trace: state.trace, + len: state.len, + client: R::client(device), + device: device.clone(), + } + } + + /// Convert the optimization to its [state](ElemwiseOptimizationState). + pub fn to_state(&self) -> ElemwiseOptimizationState { + ElemwiseOptimizationState { + trace: self.trace.clone(), + len: self.len, + } + } +} + +pub struct ElemwiseRunner; + +impl TraceRunner for ElemwiseRunner { + type Error = (); // No error possible + + fn run<'a>( + &'a self, + client: &'a ComputeClient, + inputs: GlobalArgsLaunch<'a, R>, + outputs: GlobalArgsLaunch<'a, R>, + config: &'a ElemwiseConfig, + ) -> Result<(), Self::Error> { + let arg = match config.ref_layout { + Arg::Input(index, _, _) => inputs.tensors.values.get(index as usize), + Arg::Output(index, _, _) => outputs.tensors.values.get(index as usize), + _ => panic!("Invalid value"), + }; + let (shape, vectorization) = match arg { + Some(val) => match &val.tensor { + TensorArg::Handle { + handle, + vectorization_factor, + } => (handle.shape, vectorization_factor), + TensorArg::Alias { .. } => panic!("Can't be an alias, got {val:?}"), + }, + None => panic!("Invalid argument"), + }; + let total_elem = shape.iter().product::() / *vectorization as usize; + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); + + unsafe { + elemwise_fuse::launch_unchecked( + client, + cube_count, + cube_dim, + inputs, + outputs, + config.clone(), + ); + }; + + Ok(()) + } +} + +#[cube(launch_unchecked)] +fn elemwise_fuse( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + #[comptime] config: &ElemwiseConfig, +) { + // We write no values for this fusion. 
+ let values = Registry::>::new(); + let args = comptime![Sequence::::new()]; + let pos = ABSOLUTE_POS; + + let length = match comptime![config.ref_layout.clone()] { + Arg::Input(index, _, _) => global_length(inputs, index), + Arg::Output(index, _, _) => global_length(outputs, index), + _ => comptime![panic!("Invalid ref layout.")], + }; + + if pos < length { + fuse_on_write::(inputs, outputs, pos, values, args, config) + } +} diff --git a/crates/burn-cubecl/src/fusion/mod.rs b/crates/burn-cubecl-fusion/src/lib.rs similarity index 50% rename from crates/burn-cubecl/src/fusion/mod.rs rename to crates/burn-cubecl-fusion/src/lib.rs index 96e1704964..9cbedd7b9a 100644 --- a/crates/burn-cubecl/src/fusion/mod.rs +++ b/crates/burn-cubecl-fusion/src/lib.rs @@ -1,7 +1,11 @@ +#[macro_use] +extern crate derive_new; + +pub mod elemwise; +pub mod matmul; + mod base; -pub(crate) mod elemwise; -pub(crate) mod matmul; pub(crate) mod on_write; pub(crate) mod tune; diff --git a/crates/burn-cubecl/src/fusion/matmul/args.rs b/crates/burn-cubecl-fusion/src/matmul/args.rs similarity index 72% rename from crates/burn-cubecl/src/fusion/matmul/args.rs rename to crates/burn-cubecl-fusion/src/matmul/args.rs index 229e04496c..23847f62a1 100644 --- a/crates/burn-cubecl/src/fusion/matmul/args.rs +++ b/crates/burn-cubecl-fusion/src/matmul/args.rs @@ -1,6 +1,6 @@ use cubecl::{linalg::matmul::components::global::args::MatmulArgs, prelude::*}; -use crate::fusion::on_write::{ +use crate::on_write::{ io::{global_rank, global_shape, global_stride, read_input}, ir::{Arg, ElemwiseConfig, GlobalArgs, GlobalArgsExpand, LayoutInfo}, kernel::fuse_on_write, @@ -36,9 +36,9 @@ impl MatmulArgs for FusedMatmulArgs { } fn read_lhs(state: &Self::State, coordinate: u32) -> Line { - let (pos, precision) = comptime! { + let pos = comptime! { match state.lhs { - Arg::Input(pos, precision, _) => (pos, precision), + Arg::Input(pos, ..) => pos, _ => panic!("Lhs isn't an input"), } }; @@ -49,16 +49,15 @@ impl MatmulArgs for FusedMatmulArgs { pos, coordinate, LayoutInfo::IsRef, - precision, &state.config, None, ) } fn read_rhs(state: &Self::State, coordinate: u32) -> Line { - let (pos, precision) = comptime! { + let pos = comptime! { match state.rhs { - Arg::Input(pos, precision, _) => (pos, precision), + Arg::Input(pos, ..) => pos, _ => panic!("Lhs isn't an input"), } }; @@ -69,13 +68,11 @@ impl MatmulArgs for FusedMatmulArgs { pos, coordinate, LayoutInfo::IsRef, - precision, &state.config, None, ) } - #[allow(unreachable_code)] fn read_window_lhs( _state: &Self::State, _start: u32, @@ -85,6 +82,7 @@ impl MatmulArgs for FusedMatmulArgs { // TODO This is a dummy return value to satisfy the type checker // before working on an implementation. // Remove the allow annotation after implementing this function. + #[allow(unreachable_code)] SharedMemory::new_lined(0, 0_u32).to_slice() } @@ -98,6 +96,7 @@ impl MatmulArgs for FusedMatmulArgs { // TODO This is a dummy return value to satisfy the type checker // before working on an implementation. // Remove the allow annotation after implementing this function. + #[allow(unreachable_code)] SharedMemory::new_lined(0, 0_u32).to_slice() } @@ -119,116 +118,116 @@ impl MatmulArgs for FusedMatmulArgs { } fn rank_lhs(state: &Self::State) -> u32 { - let (pos, precision) = comptime! { + let pos = comptime! { match state.lhs { - Arg::Input(pos, precision, _) => (pos, precision), + Arg::Input(pos, ..) 
=> pos, _ => panic!("Lhs isn't an input"), } }; - global_rank(unsafe { &(*state.inputs) }, pos, precision) + global_rank(unsafe { &(*state.inputs) }, pos) } fn rank_rhs(state: &Self::State) -> u32 { - let (pos, precision) = comptime! { + let pos = comptime! { match state.rhs { - Arg::Input(pos, precision, _) => (pos, precision), + Arg::Input(pos, ..) => pos, _ => panic!("Rhs isn't an input"), } }; - global_rank(unsafe { &(*state.inputs) }, pos, precision) + global_rank(unsafe { &(*state.inputs) }, pos) } fn rank_out(state: &Self::State) -> u32 { - let (pos, precision, is_input) = comptime! { + let (pos, is_input) = comptime! { match state.config.ref_layout { - Arg::Input(pos, precision, _) => (pos, precision, true), - Arg::Output(pos, precision, _) => (pos, precision, false), + Arg::Input(pos, ..) => (pos, true), + Arg::Output(pos, ..) => (pos, false), _ => panic!("Out isn't an input or output"), } }; if is_input { - global_rank(unsafe { &(*state.inputs) }, pos, precision) + global_rank(unsafe { &(*state.inputs) }, pos) } else { - global_rank(unsafe { &(*state.outputs) }, pos, precision) + global_rank(unsafe { &(*state.outputs) }, pos) } } fn shape_lhs(state: &Self::State, dim: u32) -> u32 { - let (pos, precision) = comptime! { + let pos = comptime! { match state.lhs { - Arg::Input(pos, precision, _) => (pos, precision), + Arg::Input(pos, ..) => pos, _ => panic!("Lhs isn't an input"), } }; - global_shape(unsafe { &(*state.inputs) }, dim, pos, precision) + global_shape(unsafe { &(*state.inputs) }, dim, pos) } fn shape_rhs(state: &Self::State, dim: u32) -> u32 { - let (pos, precision) = comptime! { + let pos = comptime! { match state.rhs { - Arg::Input(pos, precision, _) => (pos, precision), + Arg::Input(pos, ..) => pos, _ => panic!("Rhs isn't an input"), } }; - global_shape(unsafe { &(*state.inputs) }, dim, pos, precision) + global_shape(unsafe { &(*state.inputs) }, dim, pos) } fn shape_out(state: &Self::State, dim: u32) -> u32 { - let (pos, precision, is_input) = comptime! { + let (pos, is_input) = comptime! { match state.config.ref_layout { - Arg::Input(pos, precision, _) => (pos, precision, true), - Arg::Output(pos, precision, _) => (pos, precision, false), + Arg::Input(pos, ..) => (pos, true), + Arg::Output(pos, ..) => (pos, false), _ => panic!("Out isn't an input or output"), } }; if is_input { - global_shape(unsafe { &(*state.inputs) }, dim, pos, precision) + global_shape(unsafe { &(*state.inputs) }, dim, pos) } else { - global_shape(unsafe { &(*state.outputs) }, dim, pos, precision) + global_shape(unsafe { &(*state.outputs) }, dim, pos) } } fn stride_lhs(state: &Self::State, dim: u32) -> u32 { - let (pos, precision) = comptime! { + let pos = comptime! { match state.lhs { - Arg::Input(pos, precision, _) => (pos, precision), + Arg::Input(pos, ..) => pos, _ => panic!("Lhs isn't an input"), } }; - global_stride(unsafe { &(*state.inputs) }, dim, pos, precision) + global_stride(unsafe { &(*state.inputs) }, dim, pos) } fn stride_rhs(state: &Self::State, dim: u32) -> u32 { - let (pos, precision) = comptime! { + let pos = comptime! { match state.rhs { - Arg::Input(pos, precision, _) => (pos, precision), + Arg::Input(pos, ..) => pos, _ => panic!("Rhs isn't an input"), } }; - global_stride(unsafe { &(*state.inputs) }, dim, pos, precision) + global_stride(unsafe { &(*state.inputs) }, dim, pos) } fn stride_out(state: &Self::State, dim: u32) -> u32 { - let (pos, precision, is_input) = comptime! { + let (pos, is_input) = comptime! 
{ match state.config.ref_layout { - Arg::Input(pos, precision, _) => (pos, precision, true), - Arg::Output(pos, precision, _) => (pos, precision, false), + Arg::Input(pos, ..) => (pos, true), + Arg::Output(pos, ..) => (pos, false), _ => panic!("Out isn't an input or output"), } }; if is_input { - global_stride(unsafe { &(*state.inputs) }, dim, pos, precision) + global_stride(unsafe { &(*state.inputs) }, dim, pos) } else { - global_stride(unsafe { &(*state.outputs) }, dim, pos, precision) + global_stride(unsafe { &(*state.outputs) }, dim, pos) } } } diff --git a/crates/burn-cubecl/src/fusion/matmul/builder.rs b/crates/burn-cubecl-fusion/src/matmul/builder.rs similarity index 84% rename from crates/burn-cubecl/src/fusion/matmul/builder.rs rename to crates/burn-cubecl-fusion/src/matmul/builder.rs index 59b7cecafe..4ca83512ff 100644 --- a/crates/burn-cubecl/src/fusion/matmul/builder.rs +++ b/crates/burn-cubecl-fusion/src/matmul/builder.rs @@ -1,26 +1,32 @@ +use std::sync::Arc; + +use super::MatmulFallbackFn; use burn_fusion::{OptimizationBuilder, OptimizationStatus}; use burn_ir::{FloatOperationIr, OperationIr}; +use cubecl::Runtime; use crate::{ - fusion::{ - on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, - CubeOptimization, - }, - CubeRuntime, + on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, + CubeOptimization, }; use super::optimization::{FusedMatmul, MatmulOptimization}; /// Fused element wise operations that are normally memory bound. -pub(crate) struct MatmulBuilder { +pub struct MatmulBuilder { builder: FuseOnWriteBuilder, builder_fallback: FuseOnWriteBuilder, device: R::Device, matmul: Option, + fallback: Arc>, } -impl MatmulBuilder { - pub fn new(device: R::Device, bool_precision: ElemwisePrecision) -> Self { +impl MatmulBuilder { + pub fn new( + device: R::Device, + bool_precision: ElemwisePrecision, + fallback: Arc>, + ) -> Self { let client = R::client(&device); let props = client.properties(); let max_bindings = props.hardware_properties().max_bindings; @@ -36,11 +42,12 @@ impl MatmulBuilder { builder_fallback: FuseOnWriteBuilder::new(max_bindings, bool_precision, settings), device, matmul: None, + fallback, } } } -impl OptimizationBuilder> for MatmulBuilder { +impl OptimizationBuilder> for MatmulBuilder { fn register(&mut self, operation: &OperationIr) { if let OptimizationStatus::Closed = self.builder.status() { return; @@ -86,6 +93,7 @@ impl OptimizationBuilder> for MatmulBuilder< self.device.clone(), self.len(), self.matmul.as_ref().unwrap().clone(), + self.fallback.clone(), ); CubeOptimization::Matmul(matmul) diff --git a/crates/burn-cubecl-fusion/src/matmul/mod.rs b/crates/burn-cubecl-fusion/src/matmul/mod.rs new file mode 100644 index 0000000000..0950d64c5c --- /dev/null +++ b/crates/burn-cubecl-fusion/src/matmul/mod.rs @@ -0,0 +1,8 @@ +pub mod builder; +pub mod optimization; + +pub(crate) mod args; +pub(crate) mod spec; +pub(crate) mod tune; + +pub use optimization::MatmulFallbackFn; diff --git a/crates/burn-cubecl/src/fusion/matmul/optimization.rs b/crates/burn-cubecl-fusion/src/matmul/optimization.rs similarity index 78% rename from crates/burn-cubecl/src/fusion/matmul/optimization.rs rename to crates/burn-cubecl-fusion/src/matmul/optimization.rs index 36e563c8be..6960529e89 100644 --- a/crates/burn-cubecl/src/fusion/matmul/optimization.rs +++ b/crates/burn-cubecl-fusion/src/matmul/optimization.rs @@ -1,19 +1,17 @@ use std::any::TypeId; +use std::sync::Arc; -use 
crate::fusion::elemwise::optimization::ElemwiseRunner; -use crate::fusion::on_write::ir::ElemwisePrecision; -use crate::kernel::matmul; -use crate::{fusion::CubeFusionHandle, CubeRuntime}; -use crate::{BoolElement, FloatElement}; +use crate::elemwise::optimization::ElemwiseRunner; +use crate::on_write::ir::ElemwisePrecision; +use crate::CubeFusionHandle; use burn_fusion::stream::Context; use burn_ir::{BinaryOpIr, TensorStatus}; -use burn_tensor::Shape; use cubecl::linalg::matmul::components; use cubecl::linalg::matmul::components::tile::accelerated::Accelerated; use cubecl::linalg::matmul::components::MatmulProblem; use cubecl::linalg::matmul::kernels::matmul::{ - MatmulSelector, SimplePipelinedSelector, SimpleSelector, SpecializedSelector, + DoubleBufferingSelector, MatmulSelector, SimpleSelector, SpecializedSelector, }; use cubecl::linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}; use cubecl::linalg::tensor::{matrix_layout, MatrixLayout}; @@ -21,7 +19,7 @@ use cubecl::{client::ComputeClient, prelude::*}; use half::{bf16, f16}; use serde::{Deserialize, Serialize}; -use crate::fusion::on_write::{ +use crate::on_write::{ ir::{Arg, ElemwiseConfig, GlobalArgsLaunch}, trace::{FuseOnWriteTrace, TraceRunner}, }; @@ -31,15 +29,24 @@ use super::spec::FusedMatmulSpec; use super::tune::fused_matmul_autotune; /// Fuse matmul operation followed by elemwise operations into a single kernel. -pub struct MatmulOptimization { +pub struct MatmulOptimization { trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, pub(crate) client: ComputeClient, pub(crate) device: R::Device, pub(crate) len: usize, - pub(crate) matmul_standard: FusedMatmul, - pub(crate) matmul_pipelined: FusedMatmul, + pub(crate) matmul_simple: FusedMatmul, + pub(crate) matmul_double_buffering: FusedMatmul, pub(crate) matmul_specialized: FusedMatmul, + fallback: Arc>, +} + +pub trait MatmulFallbackFn: Send + Sync { + fn run( + &self, + lhs: (CubeFusionHandle, &[usize]), + rhs: (CubeFusionHandle, &[usize]), + ) -> CubeFusionHandle; } #[derive(Serialize, Deserialize, Debug)] @@ -47,13 +54,13 @@ pub struct MatmulOptimization { pub struct MatmulOptimizationState { trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, - matmul_standard: FusedMatmul, - matmul_pipelined: FusedMatmul, + matmul_simple: FusedMatmul, + matmul_double_buffering: FusedMatmul, matmul_specialized: FusedMatmul, len: usize, } -impl MatmulOptimization { +impl MatmulOptimization { pub fn new( trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, @@ -61,14 +68,15 @@ impl MatmulOptimization { device: R::Device, len: usize, matmul: FusedMatmul, + fallback: Arc>, ) -> Self { - let mut matmul_standard = matmul.clone(); + let mut matmul_simple = matmul.clone(); let mut matmul_specialized = matmul.clone(); - let mut matmul_pipelined = matmul; + let mut matmul_double_buffering = matmul; - matmul_standard.selector = FusedMatmulSelector::Standard; + matmul_simple.selector = FusedMatmulSelector::Simple; matmul_specialized.selector = FusedMatmulSelector::Specialized; - matmul_pipelined.selector = FusedMatmulSelector::Pipelined; + matmul_double_buffering.selector = FusedMatmulSelector::DoubleBuffering; Self { trace, @@ -76,13 +84,14 @@ impl MatmulOptimization { client, device, len, - matmul_standard, - matmul_pipelined, + matmul_simple, + matmul_double_buffering, matmul_specialized, + fallback, } } /// Execute the optimization. 
- pub fn execute(&mut self, context: &mut Context<'_, CubeFusionHandle>) { + pub fn execute(&mut self, context: &mut Context<'_, CubeFusionHandle>) { #[cfg(feature = "autotune")] fused_matmul_autotune::(self, context); @@ -98,16 +107,21 @@ impl MatmulOptimization { } /// Create an optimization from its [state](MatmulOptimizationState). - pub fn from_state(device: &R::Device, state: MatmulOptimizationState) -> Self { + pub fn from_state( + device: &R::Device, + state: MatmulOptimizationState, + fallback: Arc>, + ) -> Self { Self { trace: state.trace, trace_fallback: state.trace_fallback, len: state.len, client: R::client(device), device: device.clone(), - matmul_standard: state.matmul_standard.clone(), + matmul_simple: state.matmul_simple.clone(), matmul_specialized: state.matmul_specialized.clone(), - matmul_pipelined: state.matmul_pipelined.clone(), + matmul_double_buffering: state.matmul_double_buffering.clone(), + fallback, } } @@ -116,9 +130,9 @@ impl MatmulOptimization { MatmulOptimizationState { trace: self.trace.clone(), trace_fallback: self.trace_fallback.clone(), - matmul_standard: self.matmul_standard.clone(), + matmul_simple: self.matmul_simple.clone(), matmul_specialized: self.matmul_specialized.clone(), - matmul_pipelined: self.matmul_pipelined.clone(), + matmul_double_buffering: self.matmul_double_buffering.clone(), len: self.len, } } @@ -128,7 +142,7 @@ impl MatmulOptimization { self.trace_fallback.outputs.len() } - pub fn execute_standard_fused( + pub fn execute_simple_fused( &self, context: &mut Context<'_, CubeFusionHandle>, ) -> Result<(), FusedMatmulError> { @@ -136,11 +150,11 @@ impl MatmulOptimization { &self.client, &self.device, context, - &self.matmul_standard, + &self.matmul_simple, ) } - pub fn execute_specialized_fused( + pub fn execute_specialized_fused( &self, context: &mut Context<'_, CubeFusionHandle>, ) -> Result<(), FusedMatmulError> { @@ -152,7 +166,7 @@ impl MatmulOptimization { ) } - pub fn execute_pipelined_fused( + pub fn execute_double_buffering_fused( &self, context: &mut Context<'_, CubeFusionHandle>, ) -> Result<(), FusedMatmulError> { @@ -160,64 +174,40 @@ impl MatmulOptimization { &self.client, &self.device, context, - &self.matmul_pipelined, + &self.matmul_double_buffering, ) } - pub fn execute_fallback( - &self, - context: &mut Context<'_, CubeFusionHandle>, - ) { - match self.matmul_standard.lhs.precision() { - ElemwisePrecision::F32 => self.run_fallback::(context), - ElemwisePrecision::F16 => self.run_fallback::(context), - ElemwisePrecision::BF16 => self.run_fallback::(context), - _ => panic!("Unsupported precision"), - } - } - - fn run_fallback( + pub fn execute_fallback( &self, context: &mut Context<'_, CubeFusionHandle>, ) { let (out_tensor, out_desc) = { let lhs = context .tensors - .get(&self.matmul_standard.op.lhs.id) + .get(&self.matmul_simple.op.lhs.id) .unwrap() .clone(); let rhs = context .tensors - .get(&self.matmul_standard.op.rhs.id) + .get(&self.matmul_simple.op.rhs.id) .unwrap() .clone(); let out = context .tensors - .get(&self.matmul_standard.op.out.id) + .get(&self.matmul_simple.op.out.id) .unwrap() .clone(); let lhs_handle = context.handles.get_handle(&lhs.id, &TensorStatus::ReadOnly); let rhs_handle = context.handles.get_handle(&rhs.id, &TensorStatus::ReadOnly); + let out_handle = self + .fallback + .run((lhs_handle, &lhs.shape), (rhs_handle, &rhs.shape)); - let lhs_tensor = lhs_handle.into_tensor(Shape { - dims: lhs.shape.clone(), - }); - let rhs_tensor = rhs_handle.into_tensor(Shape { - dims: rhs.shape.clone(), - }); - 
let out_tensor = matmul::matmul::( - lhs_tensor, - rhs_tensor, - None, - matmul::MatmulStrategy::default(), - ) - .unwrap(); - (out_tensor, out) + (out_handle, out) }; - context - .handles - .register_handle(out_desc.id, CubeFusionHandle::from(out_tensor)); + context.handles.register_handle(out_desc.id, out_tensor); self.trace_fallback .run::(&self.client, &self.device, context, &ElemwiseRunner) @@ -228,8 +218,8 @@ impl MatmulOptimization { #[derive(Default, Clone, Serialize, Deserialize, Debug)] pub enum FusedMatmulSelector { #[default] - Standard, - Pipelined, + Simple, + DoubleBuffering, Specialized, } @@ -254,7 +244,7 @@ impl From for FusedMatmulError { } } -impl TraceRunner for FusedMatmul { +impl TraceRunner for FusedMatmul { type Error = FusedMatmulError; fn run<'a>( @@ -276,7 +266,7 @@ impl TraceRunner for FusedMatmul { } impl FusedMatmul { - fn matmul_fused<'a, R: CubeRuntime, EG: Numeric>( + fn matmul_fused<'a, R: Runtime, EG: Numeric>( &'a self, client: &'a ComputeClient, inputs: GlobalArgsLaunch<'a, R>, @@ -360,7 +350,7 @@ impl FusedMatmul { }; match self.selector { - FusedMatmulSelector::Standard => { + FusedMatmulSelector::Simple => { match matmul_launch_kernel::>( client, FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), @@ -372,8 +362,8 @@ impl FusedMatmul { Err(err) => Err(FusedMatmulError::LaunchError(err)), } } - FusedMatmulSelector::Pipelined => { - match matmul_launch_kernel::>( + FusedMatmulSelector::DoubleBuffering => { + match matmul_launch_kernel::>( client, FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), outputs, diff --git a/crates/burn-cubecl/src/fusion/matmul/spec.rs b/crates/burn-cubecl-fusion/src/matmul/spec.rs similarity index 100% rename from crates/burn-cubecl/src/fusion/matmul/spec.rs rename to crates/burn-cubecl-fusion/src/matmul/spec.rs diff --git a/crates/burn-cubecl/src/fusion/matmul/tune.rs b/crates/burn-cubecl-fusion/src/matmul/tune.rs similarity index 69% rename from crates/burn-cubecl/src/fusion/matmul/tune.rs rename to crates/burn-cubecl-fusion/src/matmul/tune.rs index 9d7eb11d8f..02b6e37476 100644 --- a/crates/burn-cubecl/src/fusion/matmul/tune.rs +++ b/crates/burn-cubecl-fusion/src/matmul/tune.rs @@ -1,15 +1,12 @@ use crate::{ - fusion::{ - tune::{TuneContext, TuneInput}, - CubeFusionHandle, - }, - kernel::matmul::MatmulAutotuneKey, - BoolElement, CubeRuntime, CubeTuneId, + tune::{TuneContext, TuneInput}, + CubeFusionHandle, }; use burn_fusion::stream::Context; use cubecl::{ + linalg::matmul::tune_key::MatmulAutotuneKey, tune::{local_tuner, LocalTuner, TunableSet}, - AutotuneKey, + AutotuneKey, CubeElement, CubeTuneId, Runtime, }; use serde::{Deserialize, Serialize}; @@ -25,16 +22,16 @@ pub struct FusedMatmulAutotuneKey { } /// Executes autotune on matmul operations -pub fn fused_matmul_autotune( +pub fn fused_matmul_autotune( optimization: &MatmulOptimization, context: &mut Context>, ) { static TUNER: LocalTuner = local_tuner!(); let tunables = TunableSet::new(create_key::, input_gen::) - .with_tunable(tune_standard_fused::) + .with_tunable(tune_simple_fused::) .with_tunable(tune_specialized_fused::) - .with_tunable(tune_pipelined_fused::) + .with_tunable(tune_double_buffering_fused::) .with_tunable(tune_fallback::); TUNER.execute( @@ -45,7 +42,7 @@ pub fn fused_matmul_autotune( ); } -pub(crate) fn create_key( +pub(crate) fn create_key( input: &TuneInput>, ) -> FusedMatmulAutotuneKey { let opt = input.optimization(); @@ -54,41 +51,37 @@ pub(crate) fn create_key( TuneContext::Fork(_) => 
panic!("Not supported when generating key"), }; - let lhs = context.tensors.get(&opt.matmul_standard.op.lhs.id).unwrap(); - let rhs = context.tensors.get(&opt.matmul_standard.op.rhs.id).unwrap(); - let out = context.tensors.get(&opt.matmul_standard.op.out.id).unwrap(); + let lhs = context.tensors.get(&opt.matmul_simple.op.lhs.id).unwrap(); + let rhs = context.tensors.get(&opt.matmul_simple.op.rhs.id).unwrap(); + let out = context.tensors.get(&opt.matmul_simple.op.out.id).unwrap(); - let key = MatmulAutotuneKey::from_shape( - &lhs.shape.clone().into(), - &rhs.shape.clone().into(), - out.dtype, - ); + let key = MatmulAutotuneKey::from_shape(&lhs.shape, &rhs.shape, out.dtype.into()); FusedMatmulAutotuneKey::new(key, opt.num_output_buffers(), opt.num_ops_fused()) } -fn input_gen( +fn input_gen( _key: &FusedMatmulAutotuneKey, input: &TuneInput>, ) -> TuneInput> { input.clone() } -fn tune_standard_fused( +fn tune_simple_fused( input: TuneInput>, ) -> Result<(), String> { let optimization = input.optimization(); let context = input.context(); match context { - TuneContext::Original(context) => optimization.execute_standard_fused::(context), + TuneContext::Original(context) => optimization.execute_simple_fused::(context), TuneContext::Fork(mut context_owned) => { - optimization.execute_standard_fused::(&mut context_owned.as_context()) + optimization.execute_simple_fused::(&mut context_owned.as_context()) } } .map_err(|e| format!("{e:?}")) } -fn tune_specialized_fused( +fn tune_specialized_fused( input: TuneInput>, ) -> Result<(), String> { let optimization = input.optimization(); @@ -103,22 +96,24 @@ fn tune_specialized_fused( .map_err(|e| format!("{e:?}")) } -fn tune_pipelined_fused( +fn tune_double_buffering_fused( input: TuneInput>, ) -> Result<(), String> { let optimization = input.optimization(); let context = input.context(); match context { - TuneContext::Original(context) => optimization.execute_pipelined_fused::(context), + TuneContext::Original(context) => { + optimization.execute_double_buffering_fused::(context) + } TuneContext::Fork(mut context_owned) => { - optimization.execute_pipelined_fused::(&mut context_owned.as_context()) + optimization.execute_double_buffering_fused::(&mut context_owned.as_context()) } } .map_err(|e| format!("{e:?}")) } -fn tune_fallback( +fn tune_fallback( input: TuneInput>, ) -> Result<(), String> { let optimization = input.optimization(); diff --git a/crates/burn-cubecl-fusion/src/on_write/base.rs b/crates/burn-cubecl-fusion/src/on_write/base.rs new file mode 100644 index 0000000000..32aada0ba1 --- /dev/null +++ b/crates/burn-cubecl-fusion/src/on_write/base.rs @@ -0,0 +1 @@ +pub(crate) const DYN_ELEM_ID: u8 = u8::MAX; diff --git a/crates/burn-cubecl/src/fusion/on_write/builder.rs b/crates/burn-cubecl-fusion/src/on_write/builder.rs similarity index 100% rename from crates/burn-cubecl/src/fusion/on_write/builder.rs rename to crates/burn-cubecl-fusion/src/on_write/builder.rs diff --git a/crates/burn-cubecl-fusion/src/on_write/io.rs b/crates/burn-cubecl-fusion/src/on_write/io.rs new file mode 100644 index 0000000000..d810e09972 --- /dev/null +++ b/crates/burn-cubecl-fusion/src/on_write/io.rs @@ -0,0 +1,432 @@ +use super::{ir::*, tensor::GlobalTensor, DYN_ELEM_ID}; +use cubecl::{ + ir::{ExpandElement, Variable}, + prelude::*, + unexpanded, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] +pub enum Transform { + Reshape(Sequence), + SwapDim(u32, u32), +} + +#[cube] +/// Read the 
value from the [arg](Arg) and cast it to the generic cube primitive. +pub fn read( + inputs: &GlobalArgs, + outputs: &GlobalArgs, + locals: &LocalArgs, + ref_pos: u32, + #[comptime] arg: Arg, + #[comptime] config: &ElemwiseConfig, +) -> Line { + match arg { + Arg::Input(pos, _precision, layout) => { + read_input(inputs, outputs, pos, ref_pos, layout, config, None) + } + Arg::Output(pos, _precision, layout) => { + read_output(inputs, outputs, pos, ref_pos, layout, config) + } + Arg::Local(pos, precision) => match comptime![precision] { + ElemwisePrecision::F32 => Line::cast_from(locals.l_f32.find(pos)), + ElemwisePrecision::F16 => Line::cast_from(locals.l_f16.find(pos)), + ElemwisePrecision::BF16 => Line::cast_from(locals.l_bf16.find(pos)), + ElemwisePrecision::U64 => Line::cast_from(locals.l_u64.find(pos)), + ElemwisePrecision::U32 => Line::cast_from(locals.l_u32.find(pos)), + ElemwisePrecision::U16 => Line::cast_from(locals.l_u16.find(pos)), + ElemwisePrecision::U8 => Line::cast_from(locals.l_u8.find(pos)), + ElemwisePrecision::I64 => Line::cast_from(locals.l_i64.find(pos)), + ElemwisePrecision::I32 => Line::cast_from(locals.l_i32.find(pos)), + ElemwisePrecision::I16 => Line::cast_from(locals.l_i16.find(pos)), + ElemwisePrecision::I8 => Line::cast_from(locals.l_i8.find(pos)), + ElemwisePrecision::Bool => Line::cast_from(locals.l_bool.find(pos)), + }, + Arg::Scalar(..) => { + let scalar = read_scalar::(inputs, arg); + Line::new(scalar) + } + Arg::ScalarShape(_) => { + let scalar = read_scalar_shape(inputs, arg); + Line::cast_from(scalar) + } + Arg::Literal(val, _precision) => Line::new(from_const_int::(val)), + Arg::InputReshaped { + original, shape, .. + } => match comptime![original.as_ref().clone()] { + Arg::Input(pos, _precision, layout) => read_input( + inputs, + outputs, + pos, + ref_pos, + layout, + config, + comptime![Some(Transform::Reshape(shape))], + ), + _ => comptime![panic!("Only input can be reshaped")], + }, + Arg::InputSwapDims { original, dims, .. 
} => match comptime![original.as_ref().clone()] { + Arg::Input(pos, _precision, layout) => read_input( + inputs, + outputs, + pos, + ref_pos, + layout, + config, + comptime![Some(Transform::SwapDim(dims.0, dims.1))], + ), + _ => comptime![panic!("Only input can be reshaped")], + }, + } +} + +#[cube] +pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: Arg) -> C { + match arg { + Arg::Scalar(pos, _precision) => { + let scalar = inputs.scalars.index(pos); + scalar.read::() + } + _ => comptime![panic!("Not a scalar")], + } +} + +#[cube] +pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: Arg) -> u32 { + match arg { + Arg::ScalarShape(pos) => { + let offset = comptime![inputs.scalars.len() - pos - 1]; + let scalar = inputs.scalars.index(offset); + + scalar.as_u32() + } + _ => comptime![panic!("Not a scalar shape")], + } +} + +#[cube] +pub fn read_input( + inputs: &GlobalArgs, + outputs: &GlobalArgs, + #[comptime] pos: u32, + ref_pos: u32, + #[comptime] layout: LayoutInfo, + #[comptime] config: &ElemwiseConfig, + #[comptime] transform: Option, +) -> Line { + let tensor = inputs.tensors.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) + } + }; + Line::cast_from(tensor.tensor[offset]) +} + +#[cube] +pub fn read_output( + inputs: &GlobalArgs, + outputs: &GlobalArgs, + pos: u32, + ref_pos: u32, + #[comptime] layout: LayoutInfo, + #[comptime] config: &ElemwiseConfig, +) -> Line { + let tensor = outputs.tensors.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, None, config, None), + }; + Line::cast_from(tensor.tensor[offset]) +} + +#[cube] +/// Write the given value at the [arg](Arg) position. 
+pub fn write( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + locals: &mut LocalArgs, + ref_pos: u32, + value: Line, + #[comptime] arg: Arg, + #[comptime] config: &ElemwiseConfig, +) { + match arg { + Arg::Output(pos, precision, layout) => { + let tensor = outputs.tensors.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, None, config, None) + } + }; + let tensor = outputs.tensors.index_mut(pos); + set_polyfill::>(comptime![precision.into_elem()]); + tensor.tensor[offset] = Line::cast_from(value); + } + Arg::Local(pos, precision) => match comptime![precision] { + ElemwisePrecision::F32 => locals.l_f32.insert(pos, Line::cast_from(value)), + ElemwisePrecision::F16 => locals.l_f16.insert(pos, Line::cast_from(value)), + ElemwisePrecision::BF16 => locals.l_bf16.insert(pos, Line::cast_from(value)), + ElemwisePrecision::U64 => locals.l_u64.insert(pos, Line::cast_from(value)), + ElemwisePrecision::U32 => locals.l_u32.insert(pos, Line::cast_from(value)), + ElemwisePrecision::U16 => locals.l_u16.insert(pos, Line::cast_from(value)), + ElemwisePrecision::U8 => locals.l_u8.insert(pos, Line::cast_from(value)), + ElemwisePrecision::I64 => locals.l_i64.insert(pos, Line::cast_from(value)), + ElemwisePrecision::I32 => locals.l_i32.insert(pos, Line::cast_from(value)), + ElemwisePrecision::I16 => locals.l_i16.insert(pos, Line::cast_from(value)), + ElemwisePrecision::I8 => locals.l_i8.insert(pos, Line::cast_from(value)), + ElemwisePrecision::Bool => locals.l_bool.insert(pos, Line::cast_from(value)), + }, + _ => comptime![panic!("Can't write into inputs and scalars")], + } +} + +#[cube] +pub(crate) fn global_offset( + inputs: &GlobalArgs, + outputs: &GlobalArgs, + index: u32, + #[comptime] arg: Arg, + #[comptime] range: Option<(u32, u32)>, + #[comptime] config: &ElemwiseConfig, +) -> u32 { + match arg { + Arg::Input(pos, _precision, _layout) => { + let tensor = inputs.tensors.index(pos); + get_offset(inputs, outputs, tensor, index, range, config, None) + } + Arg::Output(pos, _precision, _layout) => { + let tensor = outputs.tensors.index(pos); + get_offset(inputs, outputs, tensor, index, range, config, None) + } + _ => todo!(), + } +} + +#[cube] +fn get_offset( + inputs: &GlobalArgs, + outputs: &GlobalArgs, + tensor: &GlobalTensor, + ref_pos: u32, + #[comptime] range: Option<(u32, u32)>, + #[comptime] config: &ElemwiseConfig, + #[comptime] transform: Option, +) -> u32 { + match comptime![config.ref_layout.clone()] { + Arg::Input(index, _precision, _) => { + let layout = inputs.tensors.index(index); + index_offset_with_layout( + inputs, + tensor, + layout, + ref_pos, + range, + config.rank, + transform, + ) + } + Arg::Output(index, _precision, _) => { + let layout = outputs.tensors.index(index); + index_offset_with_layout( + inputs, + tensor, + layout, + ref_pos, + range, + config.rank, + transform, + ) + } + _ => comptime![panic!("Invalid ref layout.")], + } +} + +#[cube] +pub fn global_line_size(global: &GlobalArgs, #[comptime] pos: u32) -> u32 { + let tensor = global.tensors.index(pos); + u32::cast_from(tensor.tensor.line_size()) +} + +#[cube] +pub fn global_length(global: &GlobalArgs, #[comptime] pos: u32) -> u32 { + let tensor = global.tensors.index(pos); + u32::cast_from(tensor.tensor.len()) +} + +#[cube] +pub fn global_rank(global: &GlobalArgs, #[comptime] pos: u32) -> u32 { + let tensor = global.tensors.index(pos); + tensor.tensor.rank() +} + +#[cube] +pub fn 
global_shape(global: &GlobalArgs, dim: u32, #[comptime] pos: u32) -> u32 { + let tensor = global.tensors.index(pos); + tensor.tensor.shape(dim) +} + +#[cube] +pub fn global_stride(global: &GlobalArgs, dim: u32, #[comptime] pos: u32) -> u32 { + let tensor = global.tensors.index(pos); + tensor.tensor.stride(dim) +} + +#[cube] +fn index_offset_with_layout( + inputs: &GlobalArgs, + tensor: &GlobalTensor, + layout: &GlobalTensor, + index: u32, + #[comptime] range: Option<(u32, u32)>, + #[comptime] rank: u32, + #[comptime] transform: Option, +) -> u32 { + match comptime![transform.clone()] { + Some(Transform::Reshape(shape)) => { + comptime![assert!( + range.is_none(), + "Can't get a range on a reshaped tensor." + )]; + let index = reshaped_index(inputs, &layout.tensor, index, rank, shape); + reshaped_index_to_original_index(&tensor.tensor, index, rank) + } + Some(Transform::SwapDim(dim1, dim2)) => { + let (start, end) = comptime! {match range { + Some(range) => range, + None => (0u32, rank), + }}; + + let offset_ref = index * layout.tensor.line_size(); + let mut offset = 0u32; + + #[unroll] + for i in start..end { + let index = comptime![swap_dims_transform(&i, (dim1, dim2))]; + let ogwl = offset_ref / layout.tensor.stride(i); + offset += ogwl % tensor.tensor.shape(index) * tensor.tensor.stride(index); + } + + offset / tensor.tensor.line_size() + } + None => { + let (start, end) = comptime! {match range { + Some(range) => range, + None => (0u32, rank), + }}; + + let offset_ref = index * layout.tensor.line_size(); + let mut offset = 0u32; + + for i in start..end { + let ogwl = offset_ref / layout.tensor.stride(i); + offset += ogwl % tensor.tensor.shape(i) * tensor.tensor.stride(i); + } + + offset / tensor.tensor.line_size() + } + } +} + +fn swap_dims_transform(i: &I, dims: (u32, u32)) -> u32 { + let i_cloned: I = i.clone(); + let i = i_cloned.value().as_const().unwrap().as_u32(); + + if i == dims.0 { + dims.1 + } else if i == dims.1 { + dims.0 + } else { + i + } +} + +#[cube] +fn reshaped_index( + inputs: &GlobalArgs, + layout: &Tensor>>, + index: u32, + #[comptime] rank: u32, + #[comptime] shape: Sequence, +) -> u32 { + let index = index * layout.line_size(); + + let mut offset = 0u32; + let mut stride_curr = 1u32; + + #[unroll] + for r in 0..rank { + let i = comptime![reverse_index(rank, r)]; + let arg = comptime![shape.index(i.clone())]; + let shape_i = read_scalar_shape(inputs, comptime![arg.clone()]); + + let ogwl = index / layout.stride(i); + offset += ogwl % shape_i * stride_curr; + + stride_curr *= shape_i; + } + + offset +} + +#[cube] +fn reshaped_index_to_original_index( + original: &Tensor>, + index_reshaped: u32, + #[comptime] rank: u32, +) -> u32 { + let mut remaining = index_reshaped; + let mut offset = 0; + + #[unroll] + for r in 0..rank { + let i = comptime![reverse_index(rank, r)]; + let shape = original.shape(comptime![i.clone()]); + let stride = original.stride(i); + + let coordinate = remaining % shape; + + remaining /= shape; + offset += coordinate * stride; + } + + offset / original.line_size() +} + +fn reverse_index>>( + rank: u32, + iter: Elem, +) -> ExpandElementTyped { + let elem = iter.into(); + let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); + let result = rank - elem - 1; + let scalar: Variable = result.into(); + let expand: ExpandElement = ExpandElement::Plain(scalar); + + expand.into() +} + +/// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion. 
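// Illustrative host-side sketch, not part of this patch: the remapping that
// index_offset_with_layout above performs on device, ignoring vectorization (line size).
// A linear position in the reference layout is decomposed with the reference strides and
// re-projected onto another tensor's shape and strides; all names below are hypothetical.
fn offset_in_target(
    ref_index: usize,
    ref_strides: &[usize],
    target_shape: &[usize],
    target_strides: &[usize],
) -> usize {
    let mut offset = 0;
    for i in 0..ref_strides.len() {
        // `/ stride` recovers the walk along dim i; `% shape` strips the contribution of
        // outer dims and collapses broadcast dims of size 1 onto coordinate 0.
        let coordinate = ref_index / ref_strides[i];
        offset += coordinate % target_shape[i] * target_strides[i];
    }
    offset
}

fn main() {
    // Reference: shape [2, 3], contiguous strides [3, 1].
    // Target: shape [1, 3] broadcast along dim 0, strides [3, 1].
    let offsets: Vec<usize> = (0..6)
        .map(|i| offset_in_target(i, &[3, 1], &[1, 3], &[3, 1]))
        .collect();
    assert_eq!(offsets, vec![0usize, 1, 2, 0, 1, 2]);
}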
+fn from_const_int(_value: u32) -> C { + unexpanded!() +} + +mod from_const_int { + use cubecl::ir::{ExpandElement, Scope, Variable}; + + use cubecl::prelude::ExpandElementTyped; + + use super::CubePrimitive; + + pub fn expand(scope: &mut Scope, value: u32) -> ExpandElementTyped { + let constant: ExpandElement = value.into(); + let constant_c = constant.as_const().unwrap().cast_to(C::as_elem(scope)); + ExpandElement::Plain(Variable::constant(constant_c)).into() + } +} diff --git a/crates/burn-cubecl/src/fusion/on_write/ir.rs b/crates/burn-cubecl-fusion/src/on_write/ir.rs similarity index 66% rename from crates/burn-cubecl/src/fusion/on_write/ir.rs rename to crates/burn-cubecl-fusion/src/on_write/ir.rs index e692cf1507..5e3c4c257c 100644 --- a/crates/burn-cubecl/src/fusion/on_write/ir.rs +++ b/crates/burn-cubecl-fusion/src/on_write/ir.rs @@ -4,6 +4,11 @@ use cubecl::prelude::*; use half::{bf16, f16}; use serde::{Deserialize, Serialize}; +use super::{ + tensor::{GlobalScalar, GlobalTensor}, + DYN_ELEM_ID, +}; + #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] /// Argument to an [elemwise operation](ElemwiseOp). pub enum Arg { @@ -109,6 +114,37 @@ pub enum ElemwiseOp { }, } +impl ElemwiseOp { + /// Element type used for the computation. + pub(crate) fn cmp_elem(&self) -> Elem { + match self { + ElemwiseOp::Add(op) => op.lhs.precision().into_elem(), + ElemwiseOp::Sub(op) => op.lhs.precision().into_elem(), + ElemwiseOp::Mul(op) => op.lhs.precision().into_elem(), + ElemwiseOp::Div(op) => op.lhs.precision().into_elem(), + ElemwiseOp::Powf(op) => op.lhs.precision().into_elem(), + ElemwiseOp::Abs(op) => op.out.precision().into_elem(), + ElemwiseOp::Exp(op) => op.out.precision().into_elem(), + ElemwiseOp::Log(op) => op.out.precision().into_elem(), + ElemwiseOp::Log1p(op) => op.out.precision().into_elem(), + ElemwiseOp::Cos(op) => op.out.precision().into_elem(), + ElemwiseOp::Sin(op) => op.out.precision().into_elem(), + ElemwiseOp::Tanh(op) => op.out.precision().into_elem(), + ElemwiseOp::Erf(op) => op.out.precision().into_elem(), + ElemwiseOp::Recip(op) => op.out.precision().into_elem(), + ElemwiseOp::Assign(op) => op.out.precision().into_elem(), + ElemwiseOp::Equal(op) => op.lhs.precision().into_elem(), + ElemwiseOp::Lower(op) => op.lhs.precision().into_elem(), + ElemwiseOp::Greater(op) => op.lhs.precision().into_elem(), + ElemwiseOp::LowerEqual(op) => op.lhs.precision().into_elem(), + ElemwiseOp::GreaterEqual(op) => op.lhs.precision().into_elem(), + ElemwiseOp::ConditionalAssign { out, .. } => out.precision().into_elem(), + ElemwiseOp::Gather { output, .. } => output.precision().into_elem(), + ElemwiseOp::Select { output, .. } => output.precision().into_elem(), + } + } +} + #[derive(CubeLaunch)] pub struct ReshapedTensor { #[cube(comptime)] @@ -120,58 +156,21 @@ pub struct ReshapedTensor { #[derive(CubeLaunch, Default)] /// Global arguments that are used for fusing [element wise operations](ElemwiseOp). 
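// Minimal sketch, not from this patch: the reworked GlobalArgs below keeps a single
// type-erased sequence of tensors and one of scalars instead of one sequence per element
// type, so a fused op refers to an argument by one global position regardless of precision.
// The enum and names here are simplified stand-ins for GlobalScalar, purely illustrative.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Scalar {
    F32(f32),
    I32(i32),
    U32(u32),
}

fn main() {
    // Previously, the n-th f32 scalar and the n-th u32 scalar shared the index n in their
    // own per-precision buffers; now each scalar has a unique position in one shared buffer.
    let scalars = vec![Scalar::F32(0.5), Scalar::U32(7), Scalar::I32(-3)];
    assert_eq!(scalars[1], Scalar::U32(7));
    assert_eq!(scalars.len(), 3);
}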
pub struct GlobalArgs { - pub t_f32: Sequence>>, - pub t_f16: Sequence>>, - pub t_bf16: Sequence>>, - pub t_i64: Sequence>>, - pub t_i32: Sequence>>, - pub t_i16: Sequence>>, - pub t_i8: Sequence>>, - pub t_u64: Sequence>>, - pub t_u32: Sequence>>, - pub t_u16: Sequence>>, - pub t_u8: Sequence>>, - pub s_f32: Sequence, - pub s_f16: Sequence, - pub s_bf16: Sequence, - pub s_i64: Sequence, - pub s_i32: Sequence, - pub s_i16: Sequence, - pub s_i8: Sequence, - pub s_u64: Sequence, - pub s_u32: Sequence, - pub s_u16: Sequence, - pub s_u8: Sequence, + pub tensors: Sequence, + pub scalars: Sequence, } impl Default for GlobalArgsLaunch<'_, R> { fn default() -> Self { - Self::new( - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - ) + Self { + tensors: Default::default(), + scalars: Default::default(), + _phantom_runtime: std::marker::PhantomData, + _phantom_a: std::marker::PhantomData, + } } } + impl GlobalArgsLaunch<'_, R> { /// Get the shape of the given [argument](Arg). /// @@ -219,34 +218,8 @@ impl GlobalArgsLaunch<'_, R> { /// If the argument isn't a global input or output tensor. pub fn resolve_arg(&self, arg: &Arg) -> &TensorArg<'_, R> { match arg { - Arg::Input(pos, precision, _) => match precision { - ElemwisePrecision::F32 => &self.t_f32.values[*pos as usize], - ElemwisePrecision::F16 => &self.t_f16.values[*pos as usize], - ElemwisePrecision::BF16 => &self.t_bf16.values[*pos as usize], - ElemwisePrecision::I64 => &self.t_i64.values[*pos as usize], - ElemwisePrecision::I32 => &self.t_i32.values[*pos as usize], - ElemwisePrecision::I16 => &self.t_i16.values[*pos as usize], - ElemwisePrecision::I8 => &self.t_i8.values[*pos as usize], - ElemwisePrecision::U64 => &self.t_u64.values[*pos as usize], - ElemwisePrecision::U32 => &self.t_u32.values[*pos as usize], - ElemwisePrecision::U16 => &self.t_u16.values[*pos as usize], - ElemwisePrecision::U8 => &self.t_u8.values[*pos as usize], - ElemwisePrecision::Bool => panic!("Unsupported yet"), - }, - Arg::Output(pos, precision, _) => match precision { - ElemwisePrecision::F32 => &self.t_f32.values[*pos as usize], - ElemwisePrecision::F16 => &self.t_f16.values[*pos as usize], - ElemwisePrecision::BF16 => &self.t_bf16.values[*pos as usize], - ElemwisePrecision::I64 => &self.t_i64.values[*pos as usize], - ElemwisePrecision::I32 => &self.t_i32.values[*pos as usize], - ElemwisePrecision::I16 => &self.t_i16.values[*pos as usize], - ElemwisePrecision::I8 => &self.t_i8.values[*pos as usize], - ElemwisePrecision::U64 => &self.t_u64.values[*pos as usize], - ElemwisePrecision::U32 => &self.t_u32.values[*pos as usize], - ElemwisePrecision::U16 => &self.t_u16.values[*pos as usize], - ElemwisePrecision::U8 => &self.t_u8.values[*pos as usize], - ElemwisePrecision::Bool => panic!("Unsupported yet"), - }, + Arg::Input(pos, _, _) => &self.tensors.values[*pos as usize].tensor, + Arg::Output(pos, _, _) => &self.tensors.values[*pos as usize].tensor, _ => panic!("Only input & output can have a shape"), } } @@ -270,6 +243,33 @@ pub struct LocalArgs { pub l_bool: Registry>, } +#[cube] +impl LocalArgs { + pub fn new() -> LocalArgs { + LocalArgs { + 
l_f32: Registry::>::new(), + l_f16: Registry::>::new(), + l_bf16: Registry::>::new(), + l_i64: Registry::>::new(), + l_i32: Registry::>::new(), + l_i16: Registry::>::new(), + l_i8: Registry::>::new(), + l_u64: Registry::>::new(), + l_u32: Registry::>::new(), + l_u16: Registry::>::new(), + l_u8: Registry::>::new(), + l_bool: Registry::>::new(), + } + } +} + +#[derive(CubeType, Clone)] +/// Keep track of all local variables that are used as argument in fused +/// [element wise operations](ElemwiseOp). +pub struct LocalArgs2 { + pub scalars: Registry>>, +} + #[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] /// Unary [element wise operation](ElemwiseOp) arguments. pub struct UnaryElemwiseArgs { @@ -330,6 +330,24 @@ impl From for ElemwisePrecision { } } } +impl ElemwisePrecision { + pub fn into_elem(self) -> Elem { + match self { + ElemwisePrecision::F32 => Elem::Float(cubecl::ir::FloatKind::F32), + ElemwisePrecision::F16 => Elem::Float(cubecl::ir::FloatKind::F16), + ElemwisePrecision::BF16 => Elem::Float(cubecl::ir::FloatKind::BF16), + ElemwisePrecision::I64 => Elem::Int(cubecl::ir::IntKind::I64), + ElemwisePrecision::I32 => Elem::Int(cubecl::ir::IntKind::I32), + ElemwisePrecision::I16 => Elem::Int(cubecl::ir::IntKind::I16), + ElemwisePrecision::I8 => Elem::Int(cubecl::ir::IntKind::I8), + ElemwisePrecision::U64 => Elem::UInt(cubecl::ir::UIntKind::U64), + ElemwisePrecision::U32 => Elem::UInt(cubecl::ir::UIntKind::U32), + ElemwisePrecision::U16 => Elem::UInt(cubecl::ir::UIntKind::U16), + ElemwisePrecision::U8 => Elem::UInt(cubecl::ir::UIntKind::U8), + ElemwisePrecision::Bool => Elem::Bool, + } + } +} impl From for ElemwisePrecision { fn from(value: DType) -> Self { diff --git a/crates/burn-cubecl-fusion/src/on_write/kernel.rs b/crates/burn-cubecl-fusion/src/on_write/kernel.rs new file mode 100644 index 0000000000..5c431cfcfd --- /dev/null +++ b/crates/burn-cubecl-fusion/src/on_write/kernel.rs @@ -0,0 +1,609 @@ +use crate::on_write::DYN_ELEM_ID; + +use super::io::*; +use super::ir::*; +use cubecl::prelude::*; + +#[cube] +/// Fuse element-wise operations at the given write position. +/// +/// You can start by writing some elements using `write_values` and `write_args`. +pub fn fuse_on_write( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + write_pos: u32, + write_values: Registry>, + #[comptime] write_args: Sequence, + #[comptime] config: &ElemwiseConfig, +) { + let mut locals = LocalArgs::new(); + + // Write the values given as arguments. + #[unroll] + for i in 0..write_args.len() { + let arg = comptime![write_args.index(i).clone()]; + let val = write_values.find(comptime![arg.clone()]); + + write::(inputs, outputs, &mut locals, write_pos, val, arg, config); + } + + #[unroll] + for index in 0..config.ops.len() { + let op = comptime! 
{ config.ops.index(index).clone() }; + set_polyfill::>(comptime![op.cmp_elem()]); + + match op { + ElemwiseOp::Add(op) => add::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Div(op) => div::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Sub(op) => sub::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Mul(op) => mul::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Powf(op) => powf::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Erf(op) => erf::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Abs(op) => abs::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Log(op) => log::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Log1p(op) => log1p::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Recip(op) => recip::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Assign(op) => assign::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Exp(op) => exp::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Cos(op) => cos::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Sin(op) => sin::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Tanh(op) => tanh::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Equal(op) => equal::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Greater(op) => greater::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::GreaterEqual(op) => greater_equal::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::Lower(op) => lower::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::LowerEqual(op) => lower_equal::>( + inputs, + outputs, + &mut locals, + write_pos, + op, + config, + ), + ElemwiseOp::ConditionalAssign { + cond, + lhs, + rhs, + out, + } => conditional_assign::>( + inputs, + outputs, + &mut locals, + write_pos, + cond, + lhs, + rhs, + out, + config, + ), + ElemwiseOp::Gather { + input, + indices, + output, + dim, + } => gather::>( + inputs, + outputs, + &mut locals, + write_pos, + dim, + input, + indices, + output, + config, + ), + ElemwiseOp::Select { + input, + indices, + output, + dim, + } => select_indices::>( + inputs, + outputs, + &mut locals, + write_pos, + dim, + input, + indices, + output, + config, + ), + } + } +} + +macro_rules! binary_op { + ($ident:ident, $op:tt) => { + #[cube] + fn $ident( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + locals: &mut LocalArgs, + write_pos: u32, + #[comptime] op: BinaryElemwiseArgs, + #[comptime] config: &ElemwiseConfig, + ) { + let lhs = read::(inputs, outputs, &locals, write_pos, op.lhs, config); + let rhs = read::(inputs, outputs, &locals, write_pos, op.rhs, config); + let result = lhs $op rhs; + + write::(inputs, outputs, locals, write_pos, result, op.out, config); + } + }; +} + +macro_rules! 
binary_func { + ($ident:ident, $func:expr, $c:tt) => { + #[cube] + fn $ident( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + locals: &mut LocalArgs, + write_pos: u32, + #[comptime] op: BinaryElemwiseArgs, + #[comptime] config: &ElemwiseConfig, + ) { + let lhs = read::(inputs, outputs, &locals, write_pos, op.lhs, config); + let rhs = read::(inputs, outputs, &locals, write_pos, op.rhs, config); + let result = $func(lhs, rhs); + + write::(inputs, outputs, locals, write_pos, result, op.out, config); + } + }; +} + +macro_rules! comparison_op { + ($ident:ident, $op:tt) => { + #[cube] + fn $ident( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + locals: &mut LocalArgs, + write_pos: u32, + #[comptime] op: BinaryElemwiseArgs, + #[comptime] config: &ElemwiseConfig, + ) { + let lhs = read::(inputs, outputs, &locals, write_pos, op.lhs, config); + let rhs = read::(inputs, outputs, &locals, write_pos, op.rhs, config); + let result = Line::new(lhs $op rhs); + + write::(inputs, outputs, locals, write_pos, result, op.out, config); + } + }; +} + +macro_rules! unary_func { + ($ident:ident, $func:expr, $c:tt) => { + #[cube] + fn $ident( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + locals: &mut LocalArgs, + write_pos: u32, + #[comptime] op: UnaryElemwiseArgs, + #[comptime] config: &ElemwiseConfig, + ) { + let input = read::(inputs, outputs, &locals, write_pos, op.input, config); + let result = $func(input); + + write::(inputs, outputs, locals, write_pos, result, op.out, config); + } + }; +} + +#[cube] +fn assign( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + locals: &mut LocalArgs, + write_pos: u32, + #[comptime] op: UnaryElemwiseArgs, + #[comptime] config: &ElemwiseConfig, +) { + let input = read::(inputs, outputs, locals, write_pos, op.input, config); + + write::(inputs, outputs, locals, write_pos, input, op.out, config); +} + +#[cube] +fn gather( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + locals: &mut LocalArgs, + write_pos: u32, + #[comptime] dim: u32, + #[comptime] input: Arg, + #[comptime] indices: Arg, + #[comptime] output: Arg, + #[comptime] config: &ElemwiseConfig, +) { + let mut index = read::(inputs, outputs, locals, write_pos, indices, config); + let (pos, _precision) = comptime! 
{ + match input { + Arg::Input(pos, precision, _) => (pos, precision), + _ => panic!("Input tensor isn't an input"), + } + }; + let line_size = match config.ref_layout { + Arg::Input(pos, _precision, _) => global_line_size(inputs, pos), + Arg::Output(pos, _precision, _) => global_line_size(outputs, pos), + _ => unreachable!(), + }; + let stride = global_stride(inputs, dim, pos); + + index *= Line::new(stride); + + if comptime![dim > 0] { + let index_before = global_offset( + inputs, + outputs, + write_pos, + comment!(input.clone()), + comptime![Some((0u32, dim))], + config, + ); + index += Line::new(index_before); + } + + if comptime![dim + 1 < config.rank] { + let index_after = global_offset( + inputs, + outputs, + write_pos, + input, + comptime![Some((dim + 1, config.rank))], + config, + ); + index += Line::new(index_after); + } + + let mut result = Line::empty(line_size); + + #[unroll] + for i in 0..line_size { + let index = index[i]; + + let input = read_input::(inputs, outputs, pos, index, LayoutInfo::IsRef, config, None); + result[i] = input[0]; + } + + write::(inputs, outputs, locals, write_pos, result, output, config); +} + +#[cube] +fn select_indices( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + locals: &mut LocalArgs, + write_pos: u32, + #[comptime] dim: u32, + #[comptime] input: Arg, + #[comptime] indices: Arg, + #[comptime] output: Arg, + #[comptime] config: &ElemwiseConfig, +) { + let (line_size_ref, stride_dim_ref, shape_dim_ref) = match config.ref_layout { + Arg::Input(pos, _, _) => ( + global_line_size(inputs, pos), + global_stride(inputs, dim, pos), + global_shape(inputs, dim, pos), + ), + Arg::Output(pos, _, _) => ( + global_line_size(outputs, pos), + global_stride(outputs, dim, pos), + global_shape(outputs, dim, pos), + ), + _ => unreachable!(), + }; + + let pos_input = comptime! { + match input { + Arg::Input(pos, ..) => pos, + _ => panic!("Input tensor isn't an input"), + } + }; + let pos_indices = match indices { + Arg::Input(pos, ..) => pos, + _ => panic!("Indices tensor isn't an input"), + }; + + let stride_input_dim = global_stride(inputs, dim, pos_input); + + let mut index = 0u32; + let mut result = Line::empty(line_size_ref); + + if comptime![dim != config.rank - 1] { + // In this scenario the select is actually broadcasted along the axis we're working on. + // + // Therefore the same indices are used to fetch multiple entries in the input tensor. 
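        // Worked example (illustrative only, not kernel code): with a reference layout of
        // shape [4, 8] and dim = 0, stride_dim_ref = 8 and shape_dim_ref = 4, so every
        // element position p inside one row of 8 yields the same coordinate_dim = p / 8 % 4.
        // One entry of `indices` therefore selects the source row for the whole line.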
+ + let write_pos_input = write_pos * line_size_ref; + let stride_input_line = global_stride(inputs, comptime![config.rank - 1], pos_input); + + if comptime![dim > 0] { + let index_before = global_offset( + inputs, + outputs, + write_pos_input, + comment!(input.clone()), + comptime![Some((0u32, dim))], + config, + ); + index += index_before; + } + + if comptime![dim + 1 < config.rank] { + let index_after = global_offset( + inputs, + outputs, + write_pos_input, + comment!(input.clone()), + comptime![Some((dim + 1, config.rank))], + config, + ); + index += index_after; + } + + let coordinate_dim = write_pos_input / stride_dim_ref % shape_dim_ref; + let offset_dim = read_input::( + inputs, + outputs, + pos_indices, + coordinate_dim, + LayoutInfo::IsRef, + config, + None, + ); + + index *= line_size_ref; + index += offset_dim[0] * stride_input_dim; + + #[unroll] + for i in 0..line_size_ref { + let input = read_input::( + inputs, + outputs, + pos_input, + index + i * stride_input_line, + LayoutInfo::IsRef, + config, + None, + ); + result[i] = input[0]; + } + } else { + // In this scenario the select is actually performed on the last dimension we're working on. + // + // Therefore we need to fetch multiple indices that correspond to different entries in the + // input tensor. + + if comptime![dim > 0] { + let index_before = global_offset( + inputs, + outputs, + write_pos, + comment!(input.clone()), + comptime![Some((0u32, dim))], + config, + ); + index += index_before; + } + + if comptime![dim + 1 < config.rank] { + let index_after = global_offset( + inputs, + outputs, + write_pos, + input, + comptime![Some((dim + 1, config.rank))], + config, + ); + index += index_after; + } + + let write_pos_indices = write_pos * line_size_ref; + + #[unroll] + for i in 0..line_size_ref { + let coordinate_dim = (write_pos_indices + i) / stride_dim_ref % shape_dim_ref; + let offset_dim = read_input::( + inputs, + outputs, + pos_indices, + coordinate_dim, + LayoutInfo::IsRef, + config, + None, + ); + + let input = read_input::( + inputs, + outputs, + pos_input, + index + (offset_dim[0] * stride_input_dim), + LayoutInfo::IsRef, + config, + None, + ); + result[i] = input[0]; + } + } + + write::(inputs, outputs, locals, write_pos, result, output, config); +} + +#[cube] +fn conditional_assign( + inputs: &GlobalArgs, + outputs: &mut GlobalArgs, + locals: &mut LocalArgs, + write_pos: u32, + #[comptime] cond: Arg, + #[comptime] lhs: Arg, + #[comptime] rhs: Arg, + #[comptime] out: Arg, + #[comptime] config: &ElemwiseConfig, +) { + let cond = read::(inputs, outputs, locals, write_pos, cond, config); + let lhs = read::(inputs, outputs, locals, write_pos, lhs, config); + let rhs = read::(inputs, outputs, locals, write_pos, rhs, config); + let result = select_many(cond, lhs, rhs); + + write::(inputs, outputs, locals, write_pos, result, out, config); +} + +binary_op!(add, +); +binary_op!(mul, *); +binary_op!(div, /); +binary_op!(sub, -); + +comparison_op!(equal, ==); +comparison_op!(greater, >); +comparison_op!(greater_equal, >=); +comparison_op!(lower, <); +comparison_op!(lower_equal, <=); + +binary_func!(powf, Line::::powf, Float); + +unary_func!(exp, Line::::exp, Float); +unary_func!(log, Line::::log, Float); +unary_func!(log1p, Line::::log1p, Float); +unary_func!(cos, Line::::cos, Float); +unary_func!(sin, Line::::sin, Float); +unary_func!(tanh, Line::::tanh, Float); +unary_func!(erf, Line::::erf, Float); +unary_func!(recip, Line::::recip, Float); +unary_func!(abs, Line::::abs, Numeric); diff --git 
a/crates/burn-cubecl/src/fusion/on_write/mod.rs b/crates/burn-cubecl-fusion/src/on_write/mod.rs similarity index 68% rename from crates/burn-cubecl/src/fusion/on_write/mod.rs rename to crates/burn-cubecl-fusion/src/on_write/mod.rs index 69bbc724d1..c6657f31d5 100644 --- a/crates/burn-cubecl/src/fusion/on_write/mod.rs +++ b/crates/burn-cubecl-fusion/src/on_write/mod.rs @@ -3,5 +3,9 @@ pub(crate) mod io; pub(crate) mod ir; pub(crate) mod kernel; pub(crate) mod settings; +pub(crate) mod tensor; + +mod base; +pub(crate) use base::*; pub mod trace; diff --git a/crates/burn-cubecl/src/fusion/on_write/settings.rs b/crates/burn-cubecl-fusion/src/on_write/settings.rs similarity index 100% rename from crates/burn-cubecl/src/fusion/on_write/settings.rs rename to crates/burn-cubecl-fusion/src/on_write/settings.rs diff --git a/crates/burn-cubecl-fusion/src/on_write/tensor.rs b/crates/burn-cubecl-fusion/src/on_write/tensor.rs new file mode 100644 index 0000000000..6e6a6dcb48 --- /dev/null +++ b/crates/burn-cubecl-fusion/src/on_write/tensor.rs @@ -0,0 +1,288 @@ +use std::hash::Hash; + +use cubecl::{ + ir::{Elem, ExpandElement, FloatKind, IntKind, Item, UIntKind}, + prelude::*, + unexpanded, +}; +use serde::{Deserialize, Serialize}; + +use super::DYN_ELEM_ID; + +#[derive(CubeType)] +pub struct GlobalTensor { + pub tensor: Tensor>>, + #[cube(comptime)] + pub elem: Elem, +} + +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)] +pub struct GlobalTensorCompilationArg { + tensor: TensorCompilationArg, + elem: Elem, +} + +#[derive(new, Debug)] +pub struct GlobalTensorArg<'a, R: Runtime> { + pub tensor: >> as LaunchArg>::RuntimeArg<'a, R>, + pub elem: Elem, +} + +#[derive(CubeType)] +pub enum GlobalScalar { + F32(f32), + F16(half::f16), + BF16(half::bf16), + I64(i64), + I32(i32), + I16(i16), + I8(i8), + U64(u64), + U32(u32), + U16(u16), + U8(u8), +} + +impl GlobalScalar { + pub fn as_u32(&self) -> u32 { + unexpanded!() + } + + pub fn read(&self) -> C { + unexpanded!() + } +} + +impl GlobalScalarExpand { + pub fn __expand_as_u32_method(&self, _scope: &mut Scope) -> ExpandElementTyped { + match self { + GlobalScalarExpand::U32(val) => val.clone(), + _ => todo!(), + } + } + pub fn __expand_read_method( + &self, + scope: &mut Scope, + ) -> ExpandElementTyped { + let dtype = C::as_elem(scope); + + match self { + GlobalScalarExpand::U64(val) => { + if dtype == Elem::UInt(cubecl::ir::UIntKind::U64) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + GlobalScalarExpand::U32(val) => { + if dtype == Elem::UInt(cubecl::ir::UIntKind::U32) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + GlobalScalarExpand::U16(val) => { + if dtype == Elem::UInt(cubecl::ir::UIntKind::U16) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + GlobalScalarExpand::F32(val) => { + if dtype == Elem::Float(cubecl::ir::FloatKind::F32) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + GlobalScalarExpand::F16(val) => { + if dtype == Elem::Float(cubecl::ir::FloatKind::F16) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, 
val.clone()) + } + } + GlobalScalarExpand::BF16(val) => { + if dtype == Elem::Float(cubecl::ir::FloatKind::BF16) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + GlobalScalarExpand::U8(val) => { + if dtype == Elem::UInt(cubecl::ir::UIntKind::U8) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + + GlobalScalarExpand::I64(val) => { + if dtype == Elem::Int(cubecl::ir::IntKind::I64) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + GlobalScalarExpand::I32(val) => { + if dtype == Elem::Int(cubecl::ir::IntKind::I32) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + GlobalScalarExpand::I16(val) => { + if dtype == Elem::Int(cubecl::ir::IntKind::I16) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + GlobalScalarExpand::I8(val) => { + if dtype == Elem::Int(cubecl::ir::IntKind::I8) { + let expand: ExpandElement = val.clone().into(); + ExpandElementTyped::from(expand.clone()) + } else { + C::__expand_cast_from(scope, val.clone()) + } + } + } + } +} + +impl LaunchArg for GlobalScalar { + type RuntimeArg<'a, R: Runtime> = GlobalScalar; + + fn compilation_arg(arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg { + match arg { + GlobalScalar::F32(_) => GlobalScalarCompilationArg::new(Elem::Float(FloatKind::F32)), + GlobalScalar::F16(_) => GlobalScalarCompilationArg::new(Elem::Float(FloatKind::F16)), + GlobalScalar::BF16(_) => GlobalScalarCompilationArg::new(Elem::Float(FloatKind::BF16)), + GlobalScalar::I64(_) => GlobalScalarCompilationArg::new(Elem::Int(IntKind::I64)), + GlobalScalar::I32(_) => GlobalScalarCompilationArg::new(Elem::Int(IntKind::I32)), + GlobalScalar::I16(_) => GlobalScalarCompilationArg::new(Elem::Int(IntKind::I16)), + GlobalScalar::I8(_) => GlobalScalarCompilationArg::new(Elem::Int(IntKind::I8)), + GlobalScalar::U64(_) => GlobalScalarCompilationArg::new(Elem::UInt(UIntKind::U64)), + GlobalScalar::U32(_) => GlobalScalarCompilationArg::new(Elem::UInt(UIntKind::U32)), + GlobalScalar::U16(_) => GlobalScalarCompilationArg::new(Elem::UInt(UIntKind::U16)), + GlobalScalar::U8(_) => GlobalScalarCompilationArg::new(Elem::UInt(UIntKind::U8)), + } + } +} + +#[derive(new, Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)] +pub struct GlobalScalarCompilationArg { + elem: Elem, +} + +impl CompilationArg for GlobalScalarCompilationArg {} + +impl LaunchArgExpand for GlobalScalar { + type CompilationArg = GlobalScalarCompilationArg; + + fn expand( + arg: &Self::CompilationArg, + builder: &mut KernelBuilder, + ) -> ::ExpandType { + let expand = builder.scalar(arg.elem); + match arg.elem { + Elem::Float(float_kind) | Elem::AtomicFloat(float_kind) => match float_kind { + FloatKind::F16 => GlobalScalarExpand::F16(expand.into()), + FloatKind::BF16 => GlobalScalarExpand::BF16(expand.into()), + FloatKind::Flex32 => GlobalScalarExpand::F32(expand.into()), + FloatKind::F32 => GlobalScalarExpand::F32(expand.into()), + FloatKind::TF32 => GlobalScalarExpand::F32(expand.into()), + FloatKind::F64 => GlobalScalarExpand::F32(expand.into()), + }, + Elem::Int(int_kind) 
| Elem::AtomicInt(int_kind) => match int_kind { + IntKind::I8 => GlobalScalarExpand::I8(expand.into()), + IntKind::I16 => GlobalScalarExpand::I16(expand.into()), + IntKind::I32 => GlobalScalarExpand::I32(expand.into()), + IntKind::I64 => GlobalScalarExpand::I64(expand.into()), + }, + Elem::UInt(uint_kind) | Elem::AtomicUInt(uint_kind) => match uint_kind { + UIntKind::U8 => GlobalScalarExpand::U8(expand.into()), + UIntKind::U16 => GlobalScalarExpand::U16(expand.into()), + UIntKind::U32 => GlobalScalarExpand::U32(expand.into()), + UIntKind::U64 => GlobalScalarExpand::U64(expand.into()), + }, + Elem::Bool => panic!("Bool should be converted first."), + } + } +} + +impl ArgSettings for GlobalScalar { + fn register(&self, launcher: &mut KernelLauncher) { + match self { + GlobalScalar::F32(val) => launcher.register_f32(*val), + GlobalScalar::F16(val) => launcher.register_f16(*val), + GlobalScalar::BF16(val) => launcher.register_bf16(*val), + GlobalScalar::I64(val) => launcher.register_i64(*val), + GlobalScalar::I32(val) => launcher.register_i32(*val), + GlobalScalar::I16(val) => launcher.register_i16(*val), + GlobalScalar::I8(val) => launcher.register_i8(*val), + GlobalScalar::U64(val) => launcher.register_u64(*val), + GlobalScalar::U32(val) => launcher.register_u32(*val), + GlobalScalar::U16(val) => launcher.register_u16(*val), + GlobalScalar::U8(val) => launcher.register_u8(*val), + } + } +} + +impl LaunchArg for GlobalTensor { + type RuntimeArg<'a, R: Runtime> = GlobalTensorArg<'a, R>; + + fn compilation_arg(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg { + let tensor = >> as LaunchArg>::compilation_arg( + &runtime_arg.tensor, + ); + GlobalTensorCompilationArg { + tensor, + elem: runtime_arg.elem, + } + } +} + +impl ArgSettings for GlobalTensorArg<'_, R> { + fn register(&self, launcher: &mut KernelLauncher) { + launcher.register_tensor(&self.tensor) + } +} + +impl CompilationArg for GlobalTensorCompilationArg {} + +impl LaunchArgExpand for GlobalTensor { + type CompilationArg = GlobalTensorCompilationArg; + + fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> GlobalTensorExpand { + let tensor = builder.input_tensor(Item::vectorized(arg.elem, arg.tensor.vectorisation)); + + GlobalTensorExpand { + tensor: tensor.into(), + elem: arg.elem, + } + } + fn expand_output( + arg: &Self::CompilationArg, + builder: &mut KernelBuilder, + ) -> GlobalTensorExpand { + let tensor = match arg.tensor.inplace { + Some(id) => builder.inplace_output(id), + None => builder.output_tensor(Item::vectorized(arg.elem, arg.tensor.vectorisation)), + }; + GlobalTensorExpand { + tensor: tensor.into(), + elem: arg.elem, + } + } +} diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/base.rs b/crates/burn-cubecl-fusion/src/on_write/trace/base.rs similarity index 81% rename from crates/burn-cubecl/src/fusion/on_write/trace/base.rs rename to crates/burn-cubecl-fusion/src/on_write/trace/base.rs index 80ce63db23..6a105d8455 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/base.rs +++ b/crates/burn-cubecl-fusion/src/on_write/trace/base.rs @@ -1,4 +1,4 @@ -use crate::{fusion::CubeFusionHandle, BoolElement, CubeRuntime}; +use crate::CubeFusionHandle; use super::{ super::{ @@ -23,7 +23,7 @@ pub struct FuseOnWriteTrace { pub outputs: RegisteredTensors, pub inputs: RegisteredTensors, pub settings: FuseSettings, - pub scalars: BTreeMap, + pub scalars: Vec<(ElemwisePrecision, u32)>, pub views: Vec, pub indexed: BTreeSet, pub shape_ref: Vec, @@ -48,7 +48,7 @@ pub enum TensorView { impl 
FuseOnWriteTrace { /// Run a trace with the given [runner](TraceRunner). - pub fn run>( + pub fn run>( &self, client: &ComputeClient, device: &R::Device, @@ -83,7 +83,7 @@ impl FuseOnWriteTrace { } } - fn rollback( + fn rollback( &self, context: &mut Context<'_, CubeFusionHandle>, handle_inputs: Vec>, @@ -107,11 +107,12 @@ impl FuseOnWriteTrace { #[derive(Default, Clone, Serialize, Deserialize, Debug)] pub struct RegisteredTensors { - tensors: BTreeMap>, + tensors: BTreeMap>, + count: u32, } impl RegisteredTensors { - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.tensors.iter().flat_map(|(precision, descriptions)| { descriptions.iter().map(|desc| (*precision, desc)) }) @@ -121,45 +122,50 @@ impl RegisteredTensors { self.tensors.values().map(|v| v.len()).sum() } - pub fn get_index(&self, precision: ElemwisePrecision, tensor_id: TensorId) -> Option { + pub fn get_index(&self, precision: ElemwisePrecision, tensor_id: TensorId) -> Option { self.tensors.get(&precision).and_then(|items| { items .iter() - .enumerate() - .find(|(_pos, tensor)| tensor.id == tensor_id) - .map(|(pos, _)| pos) + .find(|(_id, tensor)| tensor.id == tensor_id) + .map(|(id, _)| *id) }) } - pub fn get_all(&self, precision: ElemwisePrecision) -> &[TensorIr] { + pub fn get_all(&self, precision: ElemwisePrecision) -> &[(u32, TensorIr)] { self.tensors .get(&precision) .map(|v| v.as_slice()) .unwrap_or(&[]) } - pub fn get(&self, precision: ElemwisePrecision, tensor_id: TensorId) -> Option<&TensorIr> { + pub fn get( + &self, + precision: ElemwisePrecision, + tensor_id: TensorId, + ) -> Option<&(u32, TensorIr)> { self.get_all(precision) .iter() - .find(|desc| desc.id == tensor_id) + .find(|(_, desc)| desc.id == tensor_id) } pub fn insert(&mut self, precision: ElemwisePrecision, tensor: TensorIr) -> u32 { + let pos = self.count; + self.count += 1; + if let Some(tensors) = self.tensors.get_mut(&precision) { - let position = tensors.len() as u32; - tensors.push(tensor); - position + tensors.push((pos, tensor)); } else { - self.tensors.insert(precision, vec![tensor]); - 0 - } + self.tensors.insert(precision, vec![(pos, tensor)]); + }; + + pos } pub fn update(&mut self, precision: ElemwisePrecision, tensor: &TensorIr) { if let Some(tensors) = self.tensors.get_mut(&precision) { - if let Some(tensor_old) = tensors + if let Some((_, tensor_old)) = tensors .iter_mut() - .find(|tensor_old| tensor_old.id == tensor.id) + .find(|(_, tensor_old)| tensor_old.id == tensor.id) { tensor_old.status = tensor.status.clone(); } diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/builder.rs b/crates/burn-cubecl-fusion/src/on_write/trace/builder.rs similarity index 97% rename from crates/burn-cubecl/src/fusion/on_write/trace/builder.rs rename to crates/burn-cubecl-fusion/src/on_write/trace/builder.rs index 56101019d0..b06935ba3f 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/builder.rs +++ b/crates/burn-cubecl-fusion/src/on_write/trace/builder.rs @@ -14,7 +14,7 @@ pub struct FuseOnWriteTraceBuilder { outputs: RegisteredTensors, settings: FuseSettings, inputs: RegisteredTensors, - scalars: BTreeMap, + scalars: Vec<(ElemwisePrecision, u32)>, views: Vec, indexed: BTreeMap, ops: Vec, @@ -31,7 +31,7 @@ impl FuseOnWriteTraceBuilder { outputs: RegisteredTensors::default(), settings, inputs: RegisteredTensors::default(), - scalars: BTreeMap::default(), + scalars: Vec::default(), views: Vec::new(), indexed: BTreeMap::new(), ops: Vec::new(), @@ -210,7 +210,7 @@ impl FuseOnWriteTraceBuilder { match 
self.inputs.get_index(precision_input, tensor.id) { Some(index) => { self.inputs.update(precision_input, tensor); - index as u32 + index } None => { return None; @@ -269,7 +269,7 @@ impl FuseOnWriteTraceBuilder { match self.inputs.get_index(precision_input, tensor.id) { Some(index) => { self.inputs.update(precision_input, tensor); - index as u32 + index } None => { return None; @@ -325,11 +325,9 @@ impl FuseOnWriteTraceBuilder { ElemwisePrecision::Bool => self.bool_precision, _ => precision, }; - let new_index = self.scalars.get(&precision).copied().unwrap_or(0); + let new_index = self.scalars.len() as u32; - let num_scalars = new_index + 1; - - self.scalars.insert(precision, num_scalars); + self.scalars.push((precision, new_index)); Arg::Scalar(new_index, precision) } @@ -343,7 +341,7 @@ impl FuseOnWriteTraceBuilder { let mut writes = BTreeMap::new(); - for (precision, tensor) in outputs.iter() { + for (precision, (_, tensor)) in outputs.iter() { let local = self.locals.get_any_precision(tensor.id).unwrap(); let out_index = outputs.get_index(precision, tensor.id).unwrap(); @@ -351,7 +349,7 @@ impl FuseOnWriteTraceBuilder { tensor.id, ElemwiseOp::Assign(UnaryElemwiseArgs { input: local, - out: Arg::Output(out_index as u32, precision, LayoutInfo::Unknown), + out: Arg::Output(out_index, precision, LayoutInfo::Unknown), }), ); } @@ -574,14 +572,14 @@ impl FuseOnWriteTraceBuilder { if !is_read { let (tensor_id, precision) = entry; - let tensor = self.outputs.get(precision, tensor_id).unwrap(); + let (_, tensor) = self.outputs.get(precision, tensor_id).unwrap(); result.insert(precision, tensor.clone()); } } // All tensors where their latest representation is read only should be written to since they // are going to be used after the fused kernel by other operations. - for (precision, tensor) in self.outputs.iter() { + for (precision, (_, tensor)) in self.outputs.iter() { if let TensorStatus::ReadOnly = tensor.status { result.insert(precision, tensor.clone()); } diff --git a/crates/burn-cubecl-fusion/src/on_write/trace/executor.rs b/crates/burn-cubecl-fusion/src/on_write/trace/executor.rs new file mode 100644 index 0000000000..482b53bf40 --- /dev/null +++ b/crates/burn-cubecl-fusion/src/on_write/trace/executor.rs @@ -0,0 +1,251 @@ +use std::marker::PhantomData; + +use burn_fusion::stream::Context; +use burn_tensor::DType; +use cubecl::{ + client::ComputeClient, + prelude::{Sequence, TensorArg}, + CubeElement, Runtime, +}; + +use super::{HandleInput, HandleOutput, LaunchPlan, TensorView, TraceRunner}; +use crate::{ + elem_dtype, + on_write::{ + ir::{ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}, + tensor::{GlobalScalar, GlobalTensorArg}, + }, + CubeFusionHandle, +}; + +/// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context). 
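// Sketch of the launch ordering only, not code from this crate: `execute` below concatenates
// the per-input read operations, the fused element-wise operations, and the final writes into
// one sequence before calling the runner. `Op` and the helper are hypothetical stand-ins.
#[derive(Debug)]
enum Op {
    Read(&'static str),
    Compute(&'static str),
    Write(&'static str),
}

fn build_sequence(reads: Vec<Op>, fused: Vec<Op>, writes: Vec<Op>) -> Vec<Op> {
    let mut ops = Vec::new();
    ops.extend(reads);
    ops.extend(fused);
    ops.extend(writes);
    ops
}

fn main() {
    let ops = build_sequence(
        vec![Op::Read("lhs"), Op::Read("rhs")],
        vec![Op::Compute("add"), Op::Compute("relu")],
        vec![Op::Write("out")],
    );
    assert_eq!(ops.len(), 5);
    assert!(matches!(ops.first(), Some(Op::Read(_))));
    assert!(matches!(ops.last(), Some(Op::Write(_))));
}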
+pub struct LaunchPlanExecutor<'a, R: Runtime> { + scalars: &'a Vec<(ElemwisePrecision, u32)>, + views: &'a Vec, + ops: &'a Vec, + _r: PhantomData, +} + +#[derive(new)] +pub struct ExecutionError> { + pub runner_error: Runner::Error, + pub handles_input: Vec>, + pub handles_output: Vec>, +} + +impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { + pub fn new( + scalars: &'a Vec<(ElemwisePrecision, u32)>, + views: &'a Vec, + ops: &'a Vec, + ) -> Self { + Self { + scalars, + views, + ops, + _r: PhantomData, + } + } + + pub fn execute, BT: CubeElement>( + self, + client: &ComputeClient, + runner: &Runner, + context: &mut Context<'_, CubeFusionHandle>, + plan: LaunchPlan<'a, R>, + ) -> Result<(), ExecutionError> { + let reference = match plan.reference { + Some(reference) => reference, + None => { + if plan.writes.is_empty() { + // Nothing to write, can skip execution. + return Ok(()); + } else { + panic!("An output should exist for the fused kernel") + } + } + }; + + let inputs = self.register_inputs(context, &plan.handle_inputs); + let outputs = self.register_outputs::(&plan.handle_outputs); + + let mut ops = Sequence::::new(); + + for read_ops in plan.reads.into_values() { + for op in read_ops { + ops.push(op); + } + } + + for op in self.ops.iter() { + ops.push(op.clone()); + } + + for op in plan.writes.into_values() { + ops.push(op); + } + + let config = ElemwiseConfig { + rank: plan.rank as u32, + ref_layout: reference.layout, + ops, + }; + + Runner::run(runner, client, inputs, outputs, &config) + .map_err(|err| ExecutionError::new(err, plan.handle_inputs, plan.handle_outputs)) + } + + fn register_inputs<'h>( + &self, + context: &mut Context<'_, CubeFusionHandle>, + handle_inputs: &'h [HandleInput], + ) -> GlobalArgsLaunch<'h, R> { + let mut inputs = GlobalArgsLaunch::default(); + + for hi in handle_inputs.iter() { + let arg = hi.handle.as_tensor_arg(&hi.global_shape, hi.vectorization); + inputs + .tensors + .push(GlobalTensorArg::new(arg, hi.precision.into_elem())); + } + + let mut index_f32 = 0; + let mut index_f16 = 0; + let mut index_bf16 = 0; + let mut index_u64 = 0; + let mut index_u32 = 0; + let mut index_u16 = 0; + let mut index_u8 = 0; + let mut index_i64 = 0; + let mut index_i32 = 0; + let mut index_i16 = 0; + let mut index_i8 = 0; + + for (precision, _pos) in self.scalars.iter() { + match precision { + ElemwisePrecision::F32 => { + inputs + .scalars + .push(GlobalScalar::F32(context.scalar_f32[index_f32])); + index_f32 += 1; + } + ElemwisePrecision::F16 => { + inputs + .scalars + .push(GlobalScalar::F16(context.scalar_f16[index_f16])); + index_f16 += 1; + } + ElemwisePrecision::BF16 => { + inputs + .scalars + .push(GlobalScalar::BF16(context.scalar_bf16[index_bf16])); + index_bf16 += 1; + } + ElemwisePrecision::I64 => { + inputs + .scalars + .push(GlobalScalar::I64(context.scalar_i64[index_i64])); + index_i64 += 1; + } + ElemwisePrecision::I32 => { + inputs + .scalars + .push(GlobalScalar::I32(context.scalar_i32[index_i32])); + index_i32 += 1; + } + ElemwisePrecision::I16 => { + inputs + .scalars + .push(GlobalScalar::I16(context.scalar_i16[index_i16])); + index_i16 += 1; + } + ElemwisePrecision::I8 => { + inputs + .scalars + .push(GlobalScalar::I8(context.scalar_i8[index_i8])); + index_i8 += 1; + } + ElemwisePrecision::U64 => { + inputs + .scalars + .push(GlobalScalar::U64(context.scalar_u64[index_u64])); + index_u64 += 1; + } + ElemwisePrecision::U32 => { + inputs + .scalars + .push(GlobalScalar::U32(context.scalar_u32[index_u32])); + index_u32 += 1; + } + ElemwisePrecision::U16 
=> { + inputs + .scalars + .push(GlobalScalar::U16(context.scalar_u16[index_u16])); + index_u16 += 1; + } + ElemwisePrecision::U8 => { + inputs + .scalars + .push(GlobalScalar::U8(context.scalar_u8[index_u8])); + index_u8 += 1; + } + ElemwisePrecision::Bool => todo!(), + } + } + + // Reshape values are pushed in reverse in the same scalar buffer for all `u32` + for relative in self.views.iter().rev() { + if let TensorView::Reshape { reshaped, .. } = relative { + let global = context.tensors.get(reshaped).unwrap(); + + for shape in global.shape.iter().rev() { + inputs.scalars.push(GlobalScalar::U32(*shape as u32)); + } + } + } + + inputs + } + + fn register_outputs<'s, BT: CubeElement>( + &self, + handle_outputs: &'s [HandleOutput], + ) -> GlobalArgsLaunch<'s, R> { + let mut outputs = GlobalArgsLaunch::default(); + + for item in handle_outputs.iter() { + match item { + HandleOutput::Alias { + input_pos, + precision, + } => { + outputs.tensors.push(GlobalTensorArg::new( + TensorArg::alias(*input_pos), + precision.into_elem(), + )); + } + HandleOutput::Owned { + precision, + handle, + global_shape, + vectorization, + .. + } => { + let arg = handle.as_tensor_arg(global_shape, *vectorization); + + let elem = match precision { + ElemwisePrecision::Bool => match elem_dtype::() { + DType::U32 => ElemwisePrecision::U32.into_elem(), + DType::U8 => ElemwisePrecision::U8.into_elem(), + _ => todo!(), + }, + _ => precision.into_elem(), + }; + outputs.tensors.push(GlobalTensorArg::new(arg, elem)); + } + } + } + + outputs + } +} diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/input.rs b/crates/burn-cubecl-fusion/src/on_write/trace/input.rs similarity index 77% rename from crates/burn-cubecl/src/fusion/on_write/trace/input.rs rename to crates/burn-cubecl-fusion/src/on_write/trace/input.rs index 0a3b4aca2f..f7ee9c004b 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/input.rs +++ b/crates/burn-cubecl-fusion/src/on_write/trace/input.rs @@ -1,17 +1,15 @@ use super::TensorView; -use crate::{ - fusion::{on_write::settings::FuseSettings, CubeFusionHandle}, - CubeRuntime, -}; +use crate::{on_write::settings::FuseSettings, CubeFusionHandle}; use burn_fusion::stream::Context; use burn_ir::{TensorId, TensorStatus}; +use cubecl::Runtime; use std::marker::PhantomData; use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors}; /// Fetch and register [input handles](HandleInput) and itendify potential inputs that /// can be used inplace. 
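// Round-trip sketch (illustrative, assuming a single reshape view for simplicity):
// `register_inputs` in executor.rs pushes reshape shape values in reverse at the end of the
// shared scalar buffer, and `read_scalar_shape` in io.rs reads position `pos` as
// `scalars[len - pos - 1]`, which recovers the dimensions in their original order.
fn read_scalar_shape(scalars: &[u32], pos: usize) -> u32 {
    scalars[scalars.len() - pos - 1]
}

fn main() {
    let shape = [2u32, 3, 4];
    // Ordinary scalars first, then the reshape dims appended in reverse.
    let mut scalars = vec![10u32, 20];
    scalars.extend(shape.iter().rev());
    assert_eq!(read_scalar_shape(&scalars, 0), 2);
    assert_eq!(read_scalar_shape(&scalars, 1), 3);
    assert_eq!(read_scalar_shape(&scalars, 2), 4);
}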
-pub struct InputPlanner<'a, R: CubeRuntime> { +pub struct InputPlanner<'a, R: Runtime> { inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, views: &'a Vec, @@ -20,7 +18,7 @@ pub struct InputPlanner<'a, R: CubeRuntime> { _r: PhantomData, } -impl<'a, R: CubeRuntime> InputPlanner<'a, R> { +impl<'a, R: Runtime> InputPlanner<'a, R> { pub fn new( inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, @@ -39,7 +37,15 @@ impl<'a, R: CubeRuntime> InputPlanner<'a, R> { } pub fn run(self, context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>) { - for (i, (precision, tensor_relative)) in self.inputs.iter().enumerate() { + let mut handles = Vec::with_capacity(self.inputs.len()); + let mut globals = Vec::with_capacity(self.inputs.len()); + + for _ in 0..self.inputs.len() { + handles.push(None); + globals.push(None); + } + + for (precision, (pos, tensor_relative)) in self.inputs.iter() { let mut tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); // Important to take the status of the relative graph and not // the global graph, since the status of the global graph @@ -62,7 +68,7 @@ impl<'a, R: CubeRuntime> InputPlanner<'a, R> { && self.shape_ref == &tensor_relative.shape { plan.potential_inplaces.push(PotentialInplace { - input_pos: i, + input_pos: *pos as usize, tensor_relative, strides: handle.strides.clone(), }); @@ -76,7 +82,7 @@ impl<'a, R: CubeRuntime> InputPlanner<'a, R> { } } - plan.handle_inputs.push(HandleInput { + handles[*pos as usize] = Some(HandleInput { precision, handle, relative_id: tensor_relative.id, @@ -84,7 +90,12 @@ impl<'a, R: CubeRuntime> InputPlanner<'a, R> { global_shape: tensor_global.shape.clone(), vectorization: 1, }); - plan.global_inputs.push(tensor_global); + globals[*pos as usize] = Some(tensor_global); + } + + for (handle, global) in handles.into_iter().zip(globals.into_iter()) { + plan.handle_inputs.push(handle.unwrap()); + plan.global_inputs.push(global.unwrap()); } } } diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/mod.rs b/crates/burn-cubecl-fusion/src/on_write/trace/mod.rs similarity index 100% rename from crates/burn-cubecl/src/fusion/on_write/trace/mod.rs rename to crates/burn-cubecl-fusion/src/on_write/trace/mod.rs diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/output.rs b/crates/burn-cubecl-fusion/src/on_write/trace/output.rs similarity index 81% rename from crates/burn-cubecl/src/fusion/on_write/trace/output.rs rename to crates/burn-cubecl-fusion/src/on_write/trace/output.rs index 03bc5cff1c..ca6d22c6aa 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/output.rs +++ b/crates/burn-cubecl-fusion/src/on_write/trace/output.rs @@ -1,33 +1,28 @@ use burn_fusion::stream::Context; use burn_ir::{TensorId, TensorIr}; use burn_tensor::DType; -use cubecl::{client::ComputeClient, ir::Elem}; +use cubecl::{client::ComputeClient, ir::Elem, CubeElement, Runtime}; use crate::{ - fusion::{ - on_write::ir::{Arg, ElemwiseOp, LayoutInfo}, - strides_dyn_rank, CubeFusionHandle, - }, - tensor::is_contiguous, - BoolElement, CubeRuntime, + elem_dtype, is_contiguous, + on_write::ir::{Arg, ElemwiseOp, LayoutInfo}, + strides_dyn_rank, CubeFusionHandle, }; use super::{ super::ir::ElemwisePrecision, HandleOutput, LaunchPlan, Reference, RegisteredTensors, TensorView, }; -use std::collections::BTreeMap; /// Create or reuse handles for the outputs. /// /// It is also responsible to select the reference tensor. 
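// Minimal stand-in (hypothetical names, not the real types): RegisteredTensors now assigns a
// single global, monotonically increasing position shared across all precisions, and the
// OutputPlanner below consumes that position directly as `pos_original`. This is what makes
// the per-precision OutputPositionMapper, deleted at the end of output.rs, unnecessary.
use std::collections::BTreeMap;

#[derive(Default)]
struct MockRegistered {
    tensors: BTreeMap<&'static str, Vec<(u32, &'static str)>>,
    count: u32,
}

impl MockRegistered {
    fn insert(&mut self, precision: &'static str, tensor: &'static str) -> u32 {
        let pos = self.count;
        self.count += 1;
        self.tensors.entry(precision).or_default().push((pos, tensor));
        pos
    }
}

fn main() {
    let mut reg = MockRegistered::default();
    // Positions are global rather than per precision, so they can index the handle buffers.
    assert_eq!(reg.insert("f32", "lhs"), 0);
    assert_eq!(reg.insert("u32", "indices"), 1);
    assert_eq!(reg.insert("f32", "out"), 2);
}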
-pub struct OutputPlanner<'a, R: CubeRuntime> { +pub struct OutputPlanner<'a, R: Runtime> { inputs: &'a RegisteredTensors, views: &'a Vec, outputs_sorted: Vec>, handles: Vec>>, globals: Vec>, - mapper: OutputPositionMapper, } struct OutputSorted<'a> { @@ -38,27 +33,25 @@ struct OutputSorted<'a> { enum OutputKind { Normal, - Inplace { input_pos: usize }, + Inplace { + /// The position in the potential inplace vector + input_pos: usize, + }, Transform(TensorView), } -impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { +impl<'a, R: Runtime> OutputPlanner<'a, R> { pub fn new( inputs: &'a RegisteredTensors, outputs: &'a RegisteredTensors, views: &'a Vec, ) -> Self { - let mut mapper = OutputPositionMapper::default(); let mut outputs_sorted: Vec<_> = outputs .iter() - .enumerate() - .map(|(pos, (precision, tensor))| { - mapper.register(precision, pos); - OutputSorted { - pos_original: pos, - precision, - tensor_relative: tensor, - } + .map(|(precision, (pos, tensor))| OutputSorted { + pos_original: *pos as usize, + precision, + tensor_relative: tensor, }) .collect(); @@ -83,11 +76,10 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { views, handles, globals, - mapper, } } - pub fn run( + pub fn run( mut self, client: &ComputeClient, device: &R::Device, @@ -156,11 +148,11 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { Self::add_layout_info_inputs(plan); } - fn add_layout_info_inputs(analysis: &mut LaunchPlan<'_, R>) { - for hi in analysis.handle_inputs.iter() { - if let Some(reference) = analysis.reference.as_ref() { + fn add_layout_info_inputs(plan: &mut LaunchPlan<'_, R>) { + for hi in plan.handle_inputs.iter() { + if let Some(reference) = plan.reference.as_ref() { if reference.strides == hi.handle.strides && reference.shape == hi.global_shape { - if let Some(ops) = analysis.reads.get_mut(&hi.relative_id) { + if let Some(ops) = plan.reads.get_mut(&hi.relative_id) { for op in ops.iter_mut() { if let ElemwiseOp::Assign(op) = op { op.input.add_layout_info(LayoutInfo::SameAsRef); @@ -194,8 +186,7 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { && pi.tensor_relative.shape == output.tensor_relative.shape && pi.strides == strides }) - .map(|(pos, _)| pos) - .map(|input_pos| OutputKind::Inplace { input_pos }) + .map(|(pos, _)| OutputKind::Inplace { input_pos: pos }) .unwrap_or(OutputKind::Normal) } @@ -217,7 +208,7 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { .unwrap(); plan.reference = Some(Reference { - layout: Arg::Input(index_input as u32, output.precision, LayoutInfo::IsRef), + layout: Arg::Input(index_input, output.precision, LayoutInfo::IsRef), shape: tensor_global.shape.clone(), strides: handle_input.handle.strides.clone(), }); @@ -247,7 +238,7 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { } #[allow(clippy::too_many_arguments)] - fn normal_output( + fn normal_output( &mut self, client: &ComputeClient, device: &R::Device, @@ -258,11 +249,12 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { strides: Vec, ) { if plan.reference.is_none() { - let position = self - .mapper - .resolve_index(&output.precision, output.pos_original); plan.reference = Some(Reference { - layout: Arg::Output(position, output.precision, LayoutInfo::IsRef), + layout: Arg::Output( + output.pos_original as u32, + output.precision, + LayoutInfo::IsRef, + ), shape: tensor_global.shape.clone(), strides: strides.clone(), }); @@ -283,7 +275,7 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { // We encode bool tensors as `B`. 
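         // (the concrete element type comes from the backend's bool representation, typically `u8` or `u32`)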
let dtype = match tensor_global.dtype { - DType::Bool => BT::dtype(), + DType::Bool => elem_dtype::(), _ => tensor_global.dtype, }; let size = tensor_global.shape.iter().product::() * Elem::from(dtype).size(); @@ -312,7 +304,7 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { } #[allow(clippy::too_many_arguments)] - fn reshaped_output( + fn reshaped_output( &mut self, client: &ComputeClient, device: &R::Device, @@ -332,7 +324,7 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { // We encode bool tensors as `B`. let dtype = match tensor_global.dtype { - DType::Bool => BT::dtype(), + DType::Bool => elem_dtype::(), _ => tensor_global.dtype, }; @@ -352,6 +344,7 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { context .handles .register_handle(tensor_global.id, handle.clone()); + // IT will never be access, just a way to keep the original position working. self.handles[output.pos_original] = Some(HandleOutput::Alias { input_pos: pos_input, @@ -372,7 +365,7 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { } #[allow(clippy::too_many_arguments)] - fn swapped_dims_output( + fn swapped_dims_output( &mut self, client: &ComputeClient, device: &R::Device, @@ -392,7 +385,7 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { // We encode bool tensors as `B`. let dtype = match tensor_global.dtype { - DType::Bool => BT::dtype(), + DType::Bool => elem_dtype::(), _ => tensor_global.dtype, }; @@ -420,32 +413,3 @@ impl<'a, R: CubeRuntime> OutputPlanner<'a, R> { self.globals[output.pos_original] = Some(tensor_global); } } - -/// Group output position by [element precision](ElemwisePrecision). -#[derive(Default, Debug)] -pub struct OutputPositionMapper { - map: BTreeMap>, -} - -impl OutputPositionMapper { - /// Register a new output with the given precision and position. - pub fn register(&mut self, precision: ElemwisePrecision, pos_handle: usize) { - if let Some(positions) = self.map.get_mut(&precision) { - positions.push(pos_handle); - } else { - self.map.insert(precision, vec![pos_handle]); - } - } - - /// Returns the right position from the precision and the global position in all outputs. - pub fn resolve_index(&mut self, precision: &ElemwisePrecision, pos_handle: usize) -> u32 { - self.map - .get(precision) - .unwrap() - .iter() - .enumerate() - .find(|(_pos_elem, pos_all)| **pos_all == pos_handle) - .map(|(pos_elem, _pos_all)| pos_elem) - .unwrap() as u32 - } -} diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/plan.rs b/crates/burn-cubecl-fusion/src/on_write/trace/plan.rs similarity index 86% rename from crates/burn-cubecl/src/fusion/on_write/trace/plan.rs rename to crates/burn-cubecl-fusion/src/on_write/trace/plan.rs index 5a25730963..f1f1620632 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/plan.rs +++ b/crates/burn-cubecl-fusion/src/on_write/trace/plan.rs @@ -1,18 +1,16 @@ use std::collections::BTreeMap; use crate::{ - fusion::{ - on_write::ir::{Arg, ElemwiseOp, ElemwisePrecision}, - CubeFusionHandle, - }, - CubeRuntime, + on_write::ir::{Arg, ElemwiseOp, ElemwisePrecision}, + CubeFusionHandle, }; use burn_ir::{TensorId, TensorIr}; +use cubecl::Runtime; /// The plan is responsible to keep runtime information related to the launch of a fused kernel /// at one place. 
#[derive(Debug)] -pub(crate) struct LaunchPlan<'a, R: CubeRuntime> { +pub(crate) struct LaunchPlan<'a, R: Runtime> { pub potential_inplaces: Vec>, pub global_inputs: Vec, pub global_outputs: Vec, @@ -25,7 +23,7 @@ pub(crate) struct LaunchPlan<'a, R: CubeRuntime> { pub rank: usize, } -impl LaunchPlan<'_, R> { +impl LaunchPlan<'_, R> { pub fn new( reads: &BTreeMap>, writes: &BTreeMap, @@ -47,7 +45,7 @@ impl LaunchPlan<'_, R> { } #[derive(Debug)] -pub enum HandleOutput { +pub enum HandleOutput { Alias { input_pos: usize, precision: ElemwisePrecision, @@ -62,7 +60,7 @@ pub enum HandleOutput { } #[derive(Debug)] -pub struct HandleInput { +pub struct HandleInput { pub relative_id: TensorId, pub global_id: TensorId, pub precision: ElemwisePrecision, diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/runner.rs b/crates/burn-cubecl-fusion/src/on_write/trace/runner.rs similarity index 98% rename from crates/burn-cubecl/src/fusion/on_write/trace/runner.rs rename to crates/burn-cubecl-fusion/src/on_write/trace/runner.rs index 1bb67a3541..2fc370f517 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/runner.rs +++ b/crates/burn-cubecl-fusion/src/on_write/trace/runner.rs @@ -1,5 +1,5 @@ use super::super::ir::{ElemwiseConfig, GlobalArgsLaunch}; -use crate::{fusion::CubeFusionHandle, CubeRuntime}; +use crate::CubeFusionHandle; use burn_ir::{TensorId, TensorIr}; use cubecl::prelude::*; use std::collections::BTreeMap; @@ -7,7 +7,7 @@ use std::collections::BTreeMap; /// A trace runner is responsible for determining the vectorization factor as well as launching /// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch) /// with a provided [element wise config](ElemwiseConfig). -pub trait TraceRunner { +pub trait TraceRunner { /// The error that might happen while running the trace. type Error; @@ -40,7 +40,7 @@ pub trait TraceRunner { } } -fn vectorization_default<'a, R: CubeRuntime>( +fn vectorization_default<'a, R: Runtime>( vectorizations: &mut BTreeMap, handles_inputs: impl Iterator>, inputs: impl Iterator, diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/vectorization.rs b/crates/burn-cubecl-fusion/src/on_write/trace/vectorization.rs similarity index 94% rename from crates/burn-cubecl/src/fusion/on_write/trace/vectorization.rs rename to crates/burn-cubecl-fusion/src/on_write/trace/vectorization.rs index 82361ff8ee..5eeb8e34dc 100644 --- a/crates/burn-cubecl/src/fusion/on_write/trace/vectorization.rs +++ b/crates/burn-cubecl-fusion/src/on_write/trace/vectorization.rs @@ -5,19 +5,17 @@ use std::{ use burn_fusion::stream::Context; use burn_ir::TensorId; +use cubecl::Runtime; use crate::{ - fusion::{ - on_write::{ir::ElemwiseOp, settings::FuseSettings}, - CubeFusionHandle, - }, - CubeRuntime, + on_write::{ir::ElemwiseOp, settings::FuseSettings}, + CubeFusionHandle, }; use super::{HandleOutput, LaunchPlan, TensorView, TraceRunner}; /// Select the best vectorization factor for each tensor handle. 
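 ///
 /// Conceptually, a line size is only usable for a tensor when it divides the innermost
 /// dimension and the memory is contiguous along it. A minimal sketch of that kind of check
 /// (hypothetical helper; the real planner also accounts for views, reads/writes and the
 /// line sizes supported by the runtime):
 ///
 /// ```ignore
 /// fn line_size_ok(shape: &[usize], strides: &[usize], line_size: u8) -> bool {
 ///     // assumes a non-empty shape
 ///     let last = shape.len() - 1;
 ///     strides[last] == 1 && shape[last] % line_size as usize == 0
 /// }
 /// ```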
-pub struct VectorizationPlanner<'a, R: CubeRuntime> { +pub struct VectorizationPlanner<'a, R: Runtime> { settings: &'a FuseSettings, views: &'a Vec, reads: &'a BTreeMap>, @@ -25,7 +23,7 @@ pub struct VectorizationPlanner<'a, R: CubeRuntime> { _r: PhantomData, } -impl<'a, R: CubeRuntime> VectorizationPlanner<'a, R> { +impl<'a, R: Runtime> VectorizationPlanner<'a, R> { pub fn new( views: &'a Vec, reads: &'a BTreeMap>, diff --git a/crates/burn-cubecl/src/fusion/tune.rs b/crates/burn-cubecl-fusion/src/tune.rs similarity index 87% rename from crates/burn-cubecl/src/fusion/tune.rs rename to crates/burn-cubecl-fusion/src/tune.rs index c6c9cbab4a..aa5794ec45 100644 --- a/crates/burn-cubecl/src/fusion/tune.rs +++ b/crates/burn-cubecl-fusion/src/tune.rs @@ -1,13 +1,13 @@ use super::CubeFusionHandle; -use crate::CubeRuntime; use burn_fusion::stream::{Context, ContextOwned}; +use cubecl::Runtime; /// Fusion context used when tuning kernels. /// /// Either the original context is returned or a fork of the original. /// The fork is only given when performing autotuning, and not when actually performing the /// operation. -pub enum TuneContext<'a, R: CubeRuntime> { +pub enum TuneContext<'a, R: Runtime> { Original(&'a mut Context<'a, CubeFusionHandle>), Fork(Box>>), } @@ -18,7 +18,7 @@ pub enum TuneContext<'a, R: CubeRuntime> { /// /// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions /// are made based on its behavior. -pub struct TuneInput { +pub struct TuneInput { context: UnsafeTuneContext, optimization: *const O, } @@ -33,15 +33,15 @@ pub struct TuneInput { /// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are /// tuned using a cloned version of the input; therefore, a fork of the context will be used to find /// the best kernel to use, which can be async. -enum UnsafeTuneContext { +enum UnsafeTuneContext { Original(*mut Context<'static, CubeFusionHandle>), Fork(Box>>), } -unsafe impl Send for UnsafeTuneContext {} -unsafe impl Send for TuneInput {} +unsafe impl Send for UnsafeTuneContext {} +unsafe impl Send for TuneInput {} -impl TuneInput { +impl TuneInput { /// Create a new autotune input from the [context](Context) and an optimization. 
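     ///
     /// A minimal usage sketch, assuming `ctx` is the fusion context handed to the autotune
     /// operation and `opt` is an optimization that lives for the whole tuning run (see the
     /// safety notes above); the names are illustrative only:
     ///
     /// ```ignore
     /// let input = TuneInput::new(ctx, &opt);
     /// let cloned = input.clone(); // autotune clones its input; the context may be forked for benchmarking
     /// ```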
pub fn new(context: &mut Context>, optimization: &O) -> Self { let context = UnsafeTuneContext::new(context); @@ -65,7 +65,7 @@ impl TuneInput { } } -impl UnsafeTuneContext { +impl UnsafeTuneContext { fn new(context: &mut Context<'_, CubeFusionHandle>) -> Self { let ptr = core::ptr::from_mut(context); @@ -84,7 +84,7 @@ impl UnsafeTuneContext { } } -impl Clone for TuneInput { +impl Clone for TuneInput { fn clone(&self) -> Self { Self { context: self.context.clone(), @@ -93,7 +93,7 @@ impl Clone for TuneInput { } } -impl Clone for UnsafeTuneContext { +impl Clone for UnsafeTuneContext { fn clone(&self) -> Self { let context = match self { UnsafeTuneContext::Original(ptr) => { diff --git a/crates/burn-cubecl/Cargo.toml b/crates/burn-cubecl/Cargo.toml index c866665230..5b7dc60821 100644 --- a/crates/burn-cubecl/Cargo.toml +++ b/crates/burn-cubecl/Cargo.toml @@ -24,7 +24,7 @@ export_tests = [ "fusion", "paste", ] -fusion = ["burn-fusion"] +fusion = ["burn-fusion", "burn-cubecl-fusion"] fusion-experimental = ["fusion"] std = ["cubecl/std", "burn-tensor/std"] @@ -33,6 +33,7 @@ template = [] [dependencies] burn-common = { path = "../burn-common", version = "0.17.0" } burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-cubecl-fusion = { path = "../burn-cubecl-fusion", version = "0.17.0", optional = true } burn-ir = { path = "../burn-ir", version = "0.17.0", default-features = false } burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "cubecl", diff --git a/crates/burn-cubecl/README.md b/crates/burn-cubecl/README.md index 036a0f41d9..d1811f39e6 100644 --- a/crates/burn-cubecl/README.md +++ b/crates/burn-cubecl/README.md @@ -1,3 +1,3 @@ -# Burn JIT Backend +# Burn CubeCL Backend Generic backend that can be compiled just-in-time (JIT) to any shader language target. diff --git a/crates/burn-cubecl/src/fusion/base.rs b/crates/burn-cubecl/src/fusion.rs similarity index 54% rename from crates/burn-cubecl/src/fusion/base.rs rename to crates/burn-cubecl/src/fusion.rs index 3d7a0b5221..0b161c89a7 100644 --- a/crates/burn-cubecl/src/fusion/base.rs +++ b/crates/burn-cubecl/src/fusion.rs @@ -1,39 +1,20 @@ -use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState}; -use super::matmul::optimization::{MatmulOptimization, MatmulOptimizationState}; -use crate::fusion::elemwise::builder::ElementWiseBuilder; -use crate::fusion::matmul::builder::MatmulBuilder; use crate::BoolElement; use crate::{kernel, tensor::CubeTensor, CubeBackend, CubeRuntime, FloatElement, IntElement}; +use burn_cubecl_fusion::elemwise::optimization::ElemwiseOptimization; +use burn_cubecl_fusion::matmul::builder::MatmulBuilder; +use burn_cubecl_fusion::matmul::optimization::MatmulOptimization; +use burn_cubecl_fusion::matmul::MatmulFallbackFn; +use burn_cubecl_fusion::CubeFusionHandle; +use burn_cubecl_fusion::{ + elemwise::builder::ElementWiseBuilder, CubeOptimization, CubeOptimizationState, +}; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_ir::{BackendIr, TensorHandle}; -use burn_tensor::{DType, Shape}; +use burn_tensor::Shape; use core::marker::PhantomData; -use cubecl::client::ComputeClient; -use cubecl::prelude::{TensorArg, TensorHandleRef}; use half::{bf16, f16}; -use serde::{Deserialize, Serialize}; - -/// Fusion optimization type for JIT. -/// -/// More optimization variants should be added here. -pub enum CubeOptimization { - /// Element wise optimization. 
- ElementWise(ElemwiseOptimization), - /// Matrix multiplication optimization. - Matmul(MatmulOptimization), -} - -/// Fusion optimization state type for JIT. -/// -/// More optimization variants should be added here. -#[derive(Serialize, Deserialize)] -pub enum CubeOptimizationState { - /// Element wise state. - ElementWise(ElemwiseOptimizationState), - /// Matrix multiplication optimization state. - Matmul(MatmulOptimizationState), -} +use std::sync::Arc; impl burn_fusion::Optimization> for CubeOptimization where @@ -66,34 +47,87 @@ where CubeOptimizationState::ElementWise(state) => { Self::ElementWise(ElemwiseOptimization::from_state(device, state)) } - CubeOptimizationState::Matmul(state) => { - Self::Matmul(MatmulOptimization::from_state(device, state)) - } + CubeOptimizationState::Matmul(state) => Self::Matmul(MatmulOptimization::from_state( + device, + state, + Arc::new(FallbackMatmul), + )), } } } +struct FallbackMatmul; + +impl MatmulFallbackFn for FallbackMatmul { + fn run( + &self, + lhs: (CubeFusionHandle, &[usize]), + rhs: (CubeFusionHandle, &[usize]), + ) -> CubeFusionHandle { + match lhs.0.dtype { + burn_tensor::DType::F64 => run_fallback_matmul::(lhs, rhs), + burn_tensor::DType::F32 => run_fallback_matmul::(lhs, rhs), + burn_tensor::DType::F16 => run_fallback_matmul::(lhs, rhs), + burn_tensor::DType::BF16 => run_fallback_matmul::(lhs, rhs), + _ => todo!("Not yet supported"), + } + } +} + +fn run_fallback_matmul( + lhs: (CubeFusionHandle, &[usize]), + rhs: (CubeFusionHandle, &[usize]), +) -> CubeFusionHandle { + let lhs_tensor = into_tensor( + lhs.0, + Shape { + dims: lhs.1.to_vec(), + }, + ); + let rhs_tensor = into_tensor( + rhs.0, + Shape { + dims: rhs.1.to_vec(), + }, + ); + let out_tensor = crate::kernel::matmul::matmul::( + lhs_tensor, + rhs_tensor, + None, + crate::kernel::matmul::MatmulStrategy::default(), + ) + .unwrap(); + + CubeFusionHandle { + client: out_tensor.client, + handle: out_tensor.handle, + device: out_tensor.device, + dtype: out_tensor.dtype, + strides: out_tensor.strides, + } +} + impl BackendIr for CubeBackend { type Handle = CubeFusionHandle; fn float_tensor(handle: TensorHandle) -> burn_tensor::ops::FloatTensor { - handle.handle.into_tensor(handle.shape) + into_tensor(handle.handle, handle.shape) } fn int_tensor(handle: TensorHandle) -> burn_tensor::ops::IntTensor { - handle.handle.into_tensor(handle.shape) + into_tensor(handle.handle, handle.shape) } fn bool_tensor(handle: TensorHandle) -> burn_tensor::ops::BoolTensor { - handle.handle.into_tensor(handle.shape) + into_tensor(handle.handle, handle.shape) } fn quantized_tensor( handle: TensorHandle, ) -> burn_tensor::ops::QuantizedTensor { - handle.handle.into_tensor(handle.shape) + into_tensor(handle.handle, handle.shape) } fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor) -> Self::Handle { @@ -132,6 +166,7 @@ impl FusionRuntime for FusionCubeRuntime Box::new(MatmulBuilder::::new( device.clone(), BT::as_elem_native_unchecked().into(), + Arc::new(FallbackMatmul), )), ] } @@ -170,89 +205,14 @@ impl FusionBack } } -pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec { - let mut strides = vec![0; shape.len()]; - - let mut current = 1; - shape.iter().enumerate().rev().for_each(|(index, val)| { - strides[index] = current; - current *= val; - }); - - strides -} - -/// Handle to be used when fusing operations. -pub struct CubeFusionHandle { - /// Compute client for jit. - pub client: ComputeClient, - /// The buffer where the data are stored. 
- pub handle: cubecl::server::Handle, - /// The device of the current tensor. - pub device: R::Device, - pub(crate) dtype: DType, - pub(crate) strides: Vec, -} - -impl core::fmt::Debug for CubeFusionHandle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!( - "CubeFusionHandle {{ device: {:?}, runtime: {}}}", - self.device, - R::name(), - )) - } -} - -impl Clone for CubeFusionHandle { - fn clone(&self) -> Self { - Self { - client: self.client.clone(), - handle: self.handle.clone(), - device: self.device.clone(), - strides: self.strides.clone(), - dtype: self.dtype, - } - } -} - -unsafe impl Send for CubeFusionHandle {} -unsafe impl Sync for CubeFusionHandle {} - -impl CubeFusionHandle { - pub(crate) fn into_tensor(self, shape: Shape) -> CubeTensor { - CubeTensor { - client: self.client, - handle: self.handle, - device: self.device, - shape, - strides: self.strides, - dtype: self.dtype, - } - } - /// Return the reference to a tensor handle. - pub fn as_handle_ref<'a>(&'a self, shape: &'a [usize]) -> TensorHandleRef<'a, R> { - TensorHandleRef { - handle: &self.handle, - strides: &self.strides, - shape, - runtime: PhantomData, - elem_size: self.dtype.size(), - } - } - /// Return the reference to a tensor argument. - pub fn as_tensor_arg<'a>(&'a self, shape: &'a [usize], vectorisation: u8) -> TensorArg<'a, R> { - let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape); - - unsafe { - TensorArg::from_raw_parts_and_size( - handle.handle, - handle.strides, - handle.shape, - vectorisation, - self.dtype.size(), - ) - } +fn into_tensor(handle: CubeFusionHandle, shape: Shape) -> CubeTensor { + CubeTensor { + client: handle.client, + handle: handle.handle, + device: handle.device, + shape, + strides: handle.strides, + dtype: handle.dtype, } } diff --git a/crates/burn-cubecl/src/fusion/elemwise/mod.rs b/crates/burn-cubecl/src/fusion/elemwise/mod.rs deleted file mode 100644 index bdbeb986d3..0000000000 --- a/crates/burn-cubecl/src/fusion/elemwise/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub(crate) mod builder; -pub(crate) mod optimization; diff --git a/crates/burn-cubecl/src/fusion/elemwise/optimization.rs b/crates/burn-cubecl/src/fusion/elemwise/optimization.rs deleted file mode 100644 index 6578ad4ff2..0000000000 --- a/crates/burn-cubecl/src/fusion/elemwise/optimization.rs +++ /dev/null @@ -1,178 +0,0 @@ -use crate::{fusion::on_write::kernel::fuse_on_write, BoolElement}; -use crate::{fusion::CubeFusionHandle, CubeRuntime}; -use burn_fusion::stream::Context; -use cubecl::{calculate_cube_count_elemwise, client::ComputeClient, prelude::*, CubeDim}; -use serde::{Deserialize, Serialize}; - -use crate::fusion::on_write::{ - ir::{Arg, ElemwiseConfig, ElemwisePrecision, GlobalArgs, GlobalArgsLaunch}, - trace::{FuseOnWriteTrace, TraceRunner}, -}; - -#[derive(new)] -/// Fuse element wise operations into a single kernel. -pub struct ElemwiseOptimization { - trace: FuseOnWriteTrace, - client: ComputeClient, - device: R::Device, - len: usize, -} - -#[derive(Serialize, Deserialize)] -/// State for the [elemwise optimization](ElemwiseOptimization). -pub struct ElemwiseOptimizationState { - trace: FuseOnWriteTrace, - len: usize, -} - -impl ElemwiseOptimization { - /// Execute the optimization. - pub fn execute(&mut self, context: &mut Context<'_, CubeFusionHandle>) { - self.trace - .run::(&self.client, &self.device, context, &ElemwiseRunner) - .unwrap(); - } - - /// Number of element wise operations fused. 
- pub fn num_ops_fused(&self) -> usize { - self.len - } - - /// Create an optimization from its [state](ElemwiseOptimizationState). - pub fn from_state(device: &R::Device, state: ElemwiseOptimizationState) -> Self { - Self { - trace: state.trace, - len: state.len, - client: R::client(device), - device: device.clone(), - } - } - - /// Convert the optimization to its [state](ElemwiseOptimizationState). - pub fn to_state(&self) -> ElemwiseOptimizationState { - ElemwiseOptimizationState { - trace: self.trace.clone(), - len: self.len, - } - } -} - -pub struct ElemwiseRunner; - -impl TraceRunner for ElemwiseRunner { - type Error = (); // No error possible - - fn run<'a>( - &'a self, - client: &'a ComputeClient, - inputs: GlobalArgsLaunch<'a, R>, - outputs: GlobalArgsLaunch<'a, R>, - config: &'a ElemwiseConfig, - ) -> Result<(), Self::Error> { - let arg = match config.ref_layout { - Arg::Input(index, precision, _) => match precision { - ElemwisePrecision::F32 => inputs.t_f32.values.get(index as usize), - ElemwisePrecision::F16 => inputs.t_f16.values.get(index as usize), - ElemwisePrecision::BF16 => inputs.t_bf16.values.get(index as usize), - ElemwisePrecision::U64 => inputs.t_u64.values.get(index as usize), - ElemwisePrecision::U32 => inputs.t_u32.values.get(index as usize), - ElemwisePrecision::U16 => inputs.t_u16.values.get(index as usize), - ElemwisePrecision::U8 => inputs.t_u8.values.get(index as usize), - ElemwisePrecision::I64 => inputs.t_i64.values.get(index as usize), - ElemwisePrecision::I32 => inputs.t_i32.values.get(index as usize), - ElemwisePrecision::I16 => inputs.t_i16.values.get(index as usize), - ElemwisePrecision::I8 => inputs.t_i8.values.get(index as usize), - _ => panic!("Invalid value"), - }, - Arg::Output(index, precision, _) => match precision { - ElemwisePrecision::F32 => outputs.t_f32.values.get(index as usize), - ElemwisePrecision::F16 => outputs.t_f16.values.get(index as usize), - ElemwisePrecision::BF16 => outputs.t_bf16.values.get(index as usize), - ElemwisePrecision::U64 => outputs.t_u64.values.get(index as usize), - ElemwisePrecision::U32 => outputs.t_u32.values.get(index as usize), - ElemwisePrecision::U16 => outputs.t_u16.values.get(index as usize), - ElemwisePrecision::U8 => outputs.t_u8.values.get(index as usize), - ElemwisePrecision::I64 => outputs.t_i64.values.get(index as usize), - ElemwisePrecision::I32 => outputs.t_i32.values.get(index as usize), - ElemwisePrecision::I16 => outputs.t_i16.values.get(index as usize), - ElemwisePrecision::I8 => outputs.t_i8.values.get(index as usize), - _ => panic!("Invalid value"), - }, - _ => panic!("Invalid value"), - }; - let (shape, vectorization) = match arg { - Some(val) => match val { - TensorArg::Handle { - handle, - vectorization_factor, - } => (handle.shape, vectorization_factor), - _ => panic!("Can't be an alias"), - }, - None => panic!("Invalid argument"), - }; - let total_elem = shape.iter().product::() / *vectorization as usize; - let cube_dim = CubeDim::default(); - let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); - - unsafe { - elemwise_fuse::launch_unchecked( - client, - cube_count, - cube_dim, - inputs, - outputs, - config.clone(), - ); - }; - - Ok(()) - } -} - -#[cube(launch_unchecked)] -fn elemwise_fuse( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - #[comptime] config: &ElemwiseConfig, -) { - // We write no values for this fusion. 
- let values = Registry::>::new(); - let args = comptime![Sequence::::new()]; - let pos = ABSOLUTE_POS; - - let length = match comptime![config.ref_layout.clone()] { - Arg::Input(index, precision, _) => match comptime![precision] { - ElemwisePrecision::F32 => inputs.t_f32.index(index).len(), - ElemwisePrecision::F16 => inputs.t_f16.index(index).len(), - ElemwisePrecision::BF16 => inputs.t_bf16.index(index).len(), - ElemwisePrecision::U64 => inputs.t_u64.index(index).len(), - ElemwisePrecision::U32 => inputs.t_u32.index(index).len(), - ElemwisePrecision::U16 => inputs.t_u16.index(index).len(), - ElemwisePrecision::U8 => inputs.t_u8.index(index).len(), - ElemwisePrecision::I64 => inputs.t_i64.index(index).len(), - ElemwisePrecision::I32 => inputs.t_i32.index(index).len(), - ElemwisePrecision::I16 => inputs.t_i16.index(index).len(), - ElemwisePrecision::I8 => inputs.t_i8.index(index).len(), - _ => comptime![panic!("Unsupported precision {precision:?}")], - }, - Arg::Output(index, precision, _) => match comptime![precision] { - ElemwisePrecision::F32 => outputs.t_f32.index(index).len(), - ElemwisePrecision::F16 => outputs.t_f16.index(index).len(), - ElemwisePrecision::BF16 => outputs.t_bf16.index(index).len(), - ElemwisePrecision::U64 => outputs.t_u64.index(index).len(), - ElemwisePrecision::U32 => outputs.t_u32.index(index).len(), - ElemwisePrecision::U16 => outputs.t_u16.index(index).len(), - ElemwisePrecision::U8 => outputs.t_u8.index(index).len(), - ElemwisePrecision::I64 => outputs.t_i64.index(index).len(), - ElemwisePrecision::I32 => outputs.t_i32.index(index).len(), - ElemwisePrecision::I16 => outputs.t_i16.index(index).len(), - ElemwisePrecision::I8 => outputs.t_i8.index(index).len(), - _ => comptime![panic!("Unsupported precision {precision:?}")], - }, - _ => comptime![panic!("Invalid ref layout.")], - }; - - if pos < length { - fuse_on_write::(inputs, outputs, pos, values, args, config) - } -} diff --git a/crates/burn-cubecl/src/fusion/matmul/mod.rs b/crates/burn-cubecl/src/fusion/matmul/mod.rs deleted file mode 100644 index cddec5983a..0000000000 --- a/crates/burn-cubecl/src/fusion/matmul/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub(crate) mod args; -pub(crate) mod builder; -pub(crate) mod optimization; -pub(crate) mod spec; -pub(crate) mod tune; diff --git a/crates/burn-cubecl/src/fusion/on_write/io.rs b/crates/burn-cubecl/src/fusion/on_write/io.rs deleted file mode 100644 index 5ecdf76410..0000000000 --- a/crates/burn-cubecl/src/fusion/on_write/io.rs +++ /dev/null @@ -1,1155 +0,0 @@ -use super::ir::*; -use cubecl::{ - ir::{ExpandElement, Variable}, - prelude::*, - unexpanded, -}; -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] -pub enum Transform { - Reshape(Sequence), - SwapDim(u32, u32), -} - -#[cube] -/// Read the value from the [arg](Arg) and cast it to the generic cube primitive. 
-pub fn read( - inputs: &GlobalArgs, - outputs: &GlobalArgs, - locals: &LocalArgs, - ref_pos: u32, - #[comptime] arg: Arg, - #[comptime] config: &ElemwiseConfig, -) -> Line { - match arg { - Arg::Input(pos, precision, layout) => read_input( - inputs, outputs, pos, ref_pos, layout, precision, config, None, - ), - Arg::Output(pos, precision, layout) => { - read_output(inputs, outputs, pos, ref_pos, layout, precision, config) - } - Arg::Local(pos, precision) => match comptime![precision] { - ElemwisePrecision::F32 => Line::cast_from(locals.l_f32.find(pos)), - ElemwisePrecision::F16 => Line::cast_from(locals.l_f16.find(pos)), - ElemwisePrecision::BF16 => Line::cast_from(locals.l_bf16.find(pos)), - ElemwisePrecision::U64 => Line::cast_from(locals.l_u64.find(pos)), - ElemwisePrecision::U32 => Line::cast_from(locals.l_u32.find(pos)), - ElemwisePrecision::U16 => Line::cast_from(locals.l_u16.find(pos)), - ElemwisePrecision::U8 => Line::cast_from(locals.l_u8.find(pos)), - ElemwisePrecision::I64 => Line::cast_from(locals.l_i64.find(pos)), - ElemwisePrecision::I32 => Line::cast_from(locals.l_i32.find(pos)), - ElemwisePrecision::I16 => Line::cast_from(locals.l_i16.find(pos)), - ElemwisePrecision::I8 => Line::cast_from(locals.l_i8.find(pos)), - ElemwisePrecision::Bool => Line::cast_from(locals.l_bool.find(pos)), - }, - Arg::Scalar(..) => { - let scalar = read_scalar::(inputs, arg); - Line::new(scalar) - } - Arg::ScalarShape(_) => { - let scalar = read_scalar_shape(inputs, arg); - Line::cast_from(scalar) - } - Arg::Literal(val, _precision) => Line::new(from_const_int::(val)), - Arg::InputReshaped { - original, shape, .. - } => match comptime![original.as_ref().clone()] { - Arg::Input(pos, precision, layout) => read_input( - inputs, - outputs, - pos, - ref_pos, - layout, - precision, - config, - comptime![Some(Transform::Reshape(shape))], - ), - _ => comptime![panic!("Only input can be reshaped")], - }, - Arg::InputSwapDims { original, dims, .. 
} => match comptime![original.as_ref().clone()] { - Arg::Input(pos, precision, layout) => read_input( - inputs, - outputs, - pos, - ref_pos, - layout, - precision, - config, - comptime![Some(Transform::SwapDim(dims.0, dims.1))], - ), - _ => comptime![panic!("Only input can be reshaped")], - }, - } -} - -#[cube] -pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: Arg) -> C { - match arg { - Arg::Scalar(pos, precision) => match comptime![precision] { - ElemwisePrecision::F32 => C::cast_from(*inputs.s_f32.index(pos)), - ElemwisePrecision::F16 => C::cast_from(*inputs.s_f16.index(pos)), - ElemwisePrecision::BF16 => C::cast_from(*inputs.s_bf16.index(pos)), - ElemwisePrecision::U64 => C::cast_from(*inputs.s_u64.index(pos)), - ElemwisePrecision::U32 => C::cast_from(*inputs.s_u32.index(pos)), - ElemwisePrecision::U16 => C::cast_from(*inputs.s_u16.index(pos)), - ElemwisePrecision::U8 => C::cast_from(*inputs.s_u8.index(pos)), - ElemwisePrecision::I64 => C::cast_from(*inputs.s_i64.index(pos)), - ElemwisePrecision::I32 => C::cast_from(*inputs.s_i32.index(pos)), - ElemwisePrecision::I16 => C::cast_from(*inputs.s_i16.index(pos)), - ElemwisePrecision::I8 => C::cast_from(*inputs.s_i8.index(pos)), - _ => comptime![panic!("Unsupported precision {precision:?}")], - }, - _ => comptime![panic!("Not a scalar")], - } -} - -#[cube] -pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: Arg) -> u32 { - match arg { - Arg::ScalarShape(pos) => { - let offset = comptime![inputs.s_u32.len() - pos - 1]; - *inputs.s_u32.index(offset) - } - _ => comptime![panic!("Not a scalar shape")], - } -} - -#[cube] -pub fn read_input( - inputs: &GlobalArgs, - outputs: &GlobalArgs, - #[comptime] pos: u32, - ref_pos: u32, - #[comptime] layout: LayoutInfo, - #[comptime] precision: ElemwisePrecision, - #[comptime] config: &ElemwiseConfig, - #[comptime] transform: Option, -) -> Line { - match comptime![precision] { - ElemwisePrecision::F32 => { - let tensor = inputs.t_f32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::F16 => { - let tensor = inputs.t_f16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::BF16 => { - let tensor = inputs.t_bf16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U64 => { - let tensor = inputs.t_u64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U32 => { - let tensor = inputs.t_u32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U16 => { - let tensor = inputs.t_u16.index(pos); 
- let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U8 => { - let tensor = inputs.t_u8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I64 => { - let tensor = inputs.t_i64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I32 => { - let tensor = inputs.t_i32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I16 => { - let tensor = inputs.t_i16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I8 => { - let tensor = inputs.t_i8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, transform) - } - }; - Line::cast_from(tensor[offset]) - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - } -} - -#[cube] -pub fn read_output( - inputs: &GlobalArgs, - outputs: &GlobalArgs, - pos: u32, - ref_pos: u32, - #[comptime] layout: LayoutInfo, - #[comptime] precision: ElemwisePrecision, - #[comptime] config: &ElemwiseConfig, -) -> Line { - match comptime![precision] { - ElemwisePrecision::F32 => { - let tensor = outputs.t_f32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::F16 => { - let tensor = outputs.t_f16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::BF16 => { - let tensor = outputs.t_bf16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U64 => { - let tensor = outputs.t_u64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U32 => { - let tensor = outputs.t_u32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => 
ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U16 => { - let tensor = outputs.t_u16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U8 => { - let tensor = outputs.t_u8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I64 => { - let tensor = outputs.t_i64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I32 => { - let tensor = outputs.t_i32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I16 => { - let tensor = outputs.t_i16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I8 => { - let tensor = outputs.t_i8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - Line::cast_from(tensor[offset]) - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - } -} - -#[cube] -/// Write the given value at the [arg](Arg) position. 
-pub fn write( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - locals: &mut LocalArgs, - ref_pos: u32, - value: Line, - #[comptime] arg: Arg, - #[comptime] config: &ElemwiseConfig, -) { - match arg { - Arg::Output(pos, precision, layout) => match comptime![precision] { - ElemwisePrecision::F32 => { - let tensor = outputs.t_f32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_f32.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - ElemwisePrecision::F16 => { - let tensor = outputs.t_f16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_f16.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - ElemwisePrecision::BF16 => { - let tensor = outputs.t_bf16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_bf16.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - ElemwisePrecision::U64 => { - let tensor = outputs.t_u64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_u64.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - ElemwisePrecision::U32 => { - let tensor = outputs.t_u32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_u32.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - ElemwisePrecision::U16 => { - let tensor = outputs.t_u16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_u16.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - ElemwisePrecision::U8 => { - let tensor = outputs.t_u8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_u8.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - ElemwisePrecision::I64 => { - let tensor = outputs.t_i64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_i64.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - ElemwisePrecision::I32 => { - let tensor = outputs.t_i32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_i32.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - 
ElemwisePrecision::I16 => { - let tensor = outputs.t_i16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_i16.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - ElemwisePrecision::I8 => { - let tensor = outputs.t_i8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => { - get_offset(inputs, outputs, tensor, ref_pos, None, config, None) - } - }; - let tensor = outputs.t_i8.index_mut(pos); - tensor[offset] = Line::cast_from(value); - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - }, - Arg::Local(pos, precision) => match comptime![precision] { - ElemwisePrecision::F32 => locals.l_f32.insert(pos, Line::cast_from(value)), - ElemwisePrecision::F16 => locals.l_f16.insert(pos, Line::cast_from(value)), - ElemwisePrecision::BF16 => locals.l_bf16.insert(pos, Line::cast_from(value)), - ElemwisePrecision::U64 => locals.l_u64.insert(pos, Line::cast_from(value)), - ElemwisePrecision::U32 => locals.l_u32.insert(pos, Line::cast_from(value)), - ElemwisePrecision::U16 => locals.l_u16.insert(pos, Line::cast_from(value)), - ElemwisePrecision::U8 => locals.l_u8.insert(pos, Line::cast_from(value)), - ElemwisePrecision::I64 => locals.l_i64.insert(pos, Line::cast_from(value)), - ElemwisePrecision::I32 => locals.l_i32.insert(pos, Line::cast_from(value)), - ElemwisePrecision::I16 => locals.l_i16.insert(pos, Line::cast_from(value)), - ElemwisePrecision::I8 => locals.l_i8.insert(pos, Line::cast_from(value)), - ElemwisePrecision::Bool => locals.l_bool.insert(pos, Line::cast_from(value)), - }, - _ => comptime![panic!("Can't write into inputs and scalars")], - } -} - -#[cube] -pub(crate) fn global_offset( - inputs: &GlobalArgs, - outputs: &GlobalArgs, - index: u32, - #[comptime] arg: Arg, - #[comptime] range: Option<(u32, u32)>, - #[comptime] config: &ElemwiseConfig, -) -> u32 { - match arg { - Arg::Input(pos, precision, _layout) => match precision { - ElemwisePrecision::F32 => get_offset( - inputs, - outputs, - inputs.t_f32.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::F16 => get_offset( - inputs, - outputs, - inputs.t_f16.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::BF16 => get_offset( - inputs, - outputs, - inputs.t_bf16.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::I64 => get_offset( - inputs, - outputs, - inputs.t_i64.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::I32 => get_offset( - inputs, - outputs, - inputs.t_i32.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::I16 => get_offset( - inputs, - outputs, - inputs.t_i16.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::I8 => get_offset( - inputs, - outputs, - inputs.t_i8.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::U64 => get_offset( - inputs, - outputs, - inputs.t_u64.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::U32 => get_offset( - inputs, - outputs, - inputs.t_u32.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::U16 => get_offset( - inputs, - outputs, - inputs.t_u16.index(pos), - index, - range, - config, - None, - ), - ElemwisePrecision::U8 => get_offset( - inputs, - outputs, - inputs.t_u8.index(pos), - index, - range, - 
config, - None, - ), - ElemwisePrecision::Bool => comptime!(panic!( - "Should be resolved to the correct bool type used by the backend" - )), - }, - _ => todo!(), - } -} - -#[cube] -fn get_offset( - inputs: &GlobalArgs, - outputs: &GlobalArgs, - tensor: &Tensor>, - pos: u32, - #[comptime] range: Option<(u32, u32)>, - #[comptime] config: &ElemwiseConfig, - #[comptime] transform: Option, -) -> u32 { - match comptime![config.ref_layout.clone()] { - Arg::Input(index, precision, _) => match comptime![precision] { - ElemwisePrecision::F32 => { - let layout = inputs.t_f32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::F16 => { - let layout = inputs.t_f16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::BF16 => { - let layout = inputs.t_bf16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::U64 => { - let layout = inputs.t_u64.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::U32 => { - let layout = inputs.t_u32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::U16 => { - let layout = inputs.t_u16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::U8 => { - let layout = inputs.t_u8.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::I64 => { - let layout = inputs.t_i64.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::I32 => { - let layout = inputs.t_i32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::I16 => { - let layout = inputs.t_i16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::I8 => { - let layout = inputs.t_i8.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - }, - Arg::Output(index, precision, _) => match comptime![precision] { - ElemwisePrecision::F32 => { - let layout = outputs.t_f32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::F16 => { - let layout = outputs.t_f16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::BF16 => { - let layout = outputs.t_bf16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::U64 => { - let layout = outputs.t_u64.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::U32 => { - let layout = outputs.t_u32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::U16 => { - let layout = outputs.t_u16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::U8 => { - let layout = outputs.t_u8.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 
range, config.rank, transform) - } - ElemwisePrecision::I64 => { - let layout = outputs.t_i64.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::I32 => { - let layout = outputs.t_i32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::I16 => { - let layout = outputs.t_i16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - ElemwisePrecision::I8 => { - let layout = outputs.t_i8.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, range, config.rank, transform) - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - }, - _ => comptime![panic!("Invalid ref layout.")], - } -} - -#[cube] -pub fn global_line_size( - global: &GlobalArgs, - #[comptime] pos: u32, - #[comptime] precision: ElemwisePrecision, -) -> u32 { - u32::cast_from(match comptime![precision] { - ElemwisePrecision::F32 => { - let tensor = global.t_f32.index(pos); - tensor.line_size() - } - ElemwisePrecision::F16 => { - let tensor = global.t_f16.index(pos); - tensor.line_size() - } - ElemwisePrecision::BF16 => { - let tensor = global.t_bf16.index(pos); - tensor.line_size() - } - ElemwisePrecision::U64 => { - let tensor = global.t_u64.index(pos); - tensor.line_size() - } - ElemwisePrecision::U32 => { - let tensor = global.t_u32.index(pos); - tensor.line_size() - } - ElemwisePrecision::U16 => { - let tensor = global.t_u16.index(pos); - tensor.line_size() - } - ElemwisePrecision::U8 => { - let tensor = global.t_u8.index(pos); - tensor.line_size() - } - ElemwisePrecision::I64 => { - let tensor = global.t_i64.index(pos); - tensor.line_size() - } - ElemwisePrecision::I32 => { - let tensor = global.t_i32.index(pos); - tensor.line_size() - } - ElemwisePrecision::I16 => { - let tensor = global.t_i16.index(pos); - tensor.line_size() - } - ElemwisePrecision::I8 => { - let tensor = global.t_i8.index(pos); - tensor.line_size() - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - }) -} - -#[cube] -pub fn global_rank( - global: &GlobalArgs, - #[comptime] pos: u32, - #[comptime] precision: ElemwisePrecision, -) -> u32 { - match comptime![precision] { - ElemwisePrecision::F32 => { - let tensor = global.t_f32.index(pos); - tensor.rank() - } - ElemwisePrecision::F16 => { - let tensor = global.t_f16.index(pos); - tensor.rank() - } - ElemwisePrecision::BF16 => { - let tensor = global.t_bf16.index(pos); - tensor.rank() - } - ElemwisePrecision::U64 => { - let tensor = global.t_u64.index(pos); - tensor.rank() - } - ElemwisePrecision::U32 => { - let tensor = global.t_u32.index(pos); - tensor.rank() - } - ElemwisePrecision::U16 => { - let tensor = global.t_u16.index(pos); - tensor.rank() - } - ElemwisePrecision::U8 => { - let tensor = global.t_u8.index(pos); - tensor.rank() - } - ElemwisePrecision::I64 => { - let tensor = global.t_i64.index(pos); - tensor.rank() - } - ElemwisePrecision::I32 => { - let tensor = global.t_i32.index(pos); - tensor.rank() - } - ElemwisePrecision::I16 => { - let tensor = global.t_i16.index(pos); - tensor.rank() - } - ElemwisePrecision::I8 => { - let tensor = global.t_i8.index(pos); - tensor.rank() - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - } -} -#[cube] -pub fn global_shape( - global: &GlobalArgs, - dim: u32, - #[comptime] pos: u32, - #[comptime] precision: ElemwisePrecision, -) -> u32 { - match comptime![precision] { - ElemwisePrecision::F32 => { - 
let tensor = global.t_f32.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::F16 => { - let tensor = global.t_f16.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::BF16 => { - let tensor = global.t_bf16.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::U64 => { - let tensor = global.t_u64.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::U32 => { - let tensor = global.t_u32.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::U16 => { - let tensor = global.t_u16.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::U8 => { - let tensor = global.t_u8.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::I64 => { - let tensor = global.t_i64.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::I32 => { - let tensor = global.t_i32.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::I16 => { - let tensor = global.t_i16.index(pos); - tensor.shape(dim) - } - ElemwisePrecision::I8 => { - let tensor = global.t_i8.index(pos); - tensor.shape(dim) - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - } -} - -#[cube] -pub fn global_stride( - global: &GlobalArgs, - dim: u32, - #[comptime] pos: u32, - #[comptime] precision: ElemwisePrecision, -) -> u32 { - match comptime![precision] { - ElemwisePrecision::F32 => { - let tensor = global.t_f32.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::F16 => { - let tensor = global.t_f16.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::BF16 => { - let tensor = global.t_bf16.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::U64 => { - let tensor = global.t_u64.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::U32 => { - let tensor = global.t_u32.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::U16 => { - let tensor = global.t_u16.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::U8 => { - let tensor = global.t_u8.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::I64 => { - let tensor = global.t_i64.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::I32 => { - let tensor = global.t_i32.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::I16 => { - let tensor = global.t_i16.index(pos); - tensor.stride(dim) - } - ElemwisePrecision::I8 => { - let tensor = global.t_i8.index(pos); - tensor.stride(dim) - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - } -} - -/// Returns the offset of the tensor corresponding to the layout tensor. -#[cube] -fn index_offset_with_layout( - inputs: &GlobalArgs, - tensor: &Tensor>, - layout: &Tensor>, - index: u32, - #[comptime] range: Option<(u32, u32)>, - #[comptime] rank: u32, - #[comptime] transform: Option, -) -> u32 { - match comptime![transform.clone()] { - Some(Transform::Reshape(shape)) => { - comptime![assert!( - range.is_none(), - "Can't get a range on a reshaped tensor." - )]; - let index = reshaped_index(inputs, layout, index, rank, shape); - reshaped_index_to_original_index(tensor, index, rank) - } - Some(Transform::SwapDim(dim1, dim2)) => { - let (start, end) = comptime! {match range { - Some(range) => range, - None => (0u32, rank), - }}; - - let offset_ref = index * layout.line_size(); - let mut offset = 0u32; - - #[unroll] - for i in start..end { - let index = comptime![swap_dims_transform(&i, (dim1, dim2))]; - let ogwl = offset_ref / layout.stride(i); - offset += ogwl % tensor.shape(index) * tensor.stride(index); - } - - offset / tensor.line_size() - } - None => { - let (start, end) = comptime! 
{match range { - Some(range) => range, - None => (0u32, rank), - }}; - - let offset_ref = index * layout.line_size(); - let mut offset = 0u32; - - for i in start..end { - let ogwl = offset_ref / layout.stride(i); - offset += ogwl % tensor.shape(i) * tensor.stride(i); - } - - offset / tensor.line_size() - } - } -} - -fn swap_dims_transform(i: &I, dims: (u32, u32)) -> u32 { - let i_cloned: I = i.clone(); - let i = i_cloned.value().as_const().unwrap().as_u32(); - - if i == dims.0 { - dims.1 - } else if i == dims.1 { - dims.0 - } else { - i - } -} - -#[cube] -fn reshaped_index( - inputs: &GlobalArgs, - layout: &Tensor>, - index: u32, - #[comptime] rank: u32, - #[comptime] shape: Sequence, -) -> u32 { - let index = index * layout.line_size(); - - let mut offset = 0u32; - let mut stride_curr = 1u32; - - #[unroll] - for r in 0..rank { - let i = comptime![reverse_index(rank, r)]; - let arg = comptime![shape.index(i.clone())]; - let shape_i = read_scalar_shape(inputs, comptime![arg.clone()]); - - let ogwl = index / layout.stride(i); - offset += ogwl % shape_i * stride_curr; - - stride_curr *= shape_i; - } - - offset -} - -#[cube] -fn reshaped_index_to_original_index( - original: &Tensor>, - index_reshaped: u32, - #[comptime] rank: u32, -) -> u32 { - let mut remaining = index_reshaped; - let mut offset = 0; - - #[unroll] - for r in 0..rank { - let i = comptime![reverse_index(rank, r)]; - let shape = original.shape(comptime![i.clone()]); - let stride = original.stride(i); - - let coordinate = remaining % shape; - - remaining /= shape; - offset += coordinate * stride; - } - - offset / original.line_size() -} - -fn reverse_index>>( - rank: u32, - iter: Elem, -) -> ExpandElementTyped { - let elem = iter.into(); - let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); - let result = rank - elem - 1; - let scalar: Variable = result.into(); - let expand: ExpandElement = ExpandElement::Plain(scalar); - - expand.into() -} -/// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion. -fn from_const_int(_value: u32) -> C { - unexpanded!() -} - -mod from_const_int { - use cubecl::ir::{ExpandElement, Scope, Variable}; - - use cubecl::prelude::ExpandElementTyped; - - use super::CubePrimitive; - - pub fn expand(scope: &mut Scope, value: u32) -> ExpandElementTyped { - let constant: ExpandElement = value.into(); - let constant_c = constant.as_const().unwrap().cast_to(C::as_elem(scope)); - ExpandElement::Plain(Variable::constant(constant_c)).into() - } -} diff --git a/crates/burn-cubecl/src/fusion/on_write/kernel.rs b/crates/burn-cubecl/src/fusion/on_write/kernel.rs deleted file mode 100644 index 73a39793fe..0000000000 --- a/crates/burn-cubecl/src/fusion/on_write/kernel.rs +++ /dev/null @@ -1,1325 +0,0 @@ -use super::io::*; -use super::ir::*; -use cubecl::prelude::*; -use half::{bf16, f16}; - -#[cube] -/// Fuse element-wise operations at the given write position. -/// -/// You can start by writing some elements using `write_values` and `write_args`. 
-pub fn fuse_on_write( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - write_pos: u32, - write_values: Registry>, - #[comptime] write_args: Sequence, - #[comptime] config: &ElemwiseConfig, -) { - let mut locals = LocalArgs { - l_f32: Registry::>::new(), - l_f16: Registry::>::new(), - l_bf16: Registry::>::new(), - l_i64: Registry::>::new(), - l_i32: Registry::>::new(), - l_i16: Registry::>::new(), - l_i8: Registry::>::new(), - l_u64: Registry::>::new(), - l_u32: Registry::>::new(), - l_u16: Registry::>::new(), - l_u8: Registry::>::new(), - l_bool: Registry::>::new(), - }; - - // Write the values given as arguments. - #[unroll] - for i in 0..write_args.len() { - let arg = comptime![write_args.index(i).clone()]; - let val = write_values.find(comptime![arg.clone()]); - - write::(inputs, outputs, &mut locals, write_pos, val, arg, config); - } - - #[unroll] - for index in 0..config.ops.len() { - let op = comptime! { config.ops.index(index).clone() }; - - match op { - ElemwiseOp::Add(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - add::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Div(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - div::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Sub(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - 
ElemwisePrecision::BF16 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - sub::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Mul(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - mul::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Powf(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - powf::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - powf::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - powf::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Erf(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - erf::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - erf::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - erf::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Abs(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - abs::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - abs::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - abs::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - 
assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - abs::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - abs::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - abs::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - abs::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Log(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - log::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - log::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - log::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Log1p(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - log1p::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - log1p::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - log1p::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Recip(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - recip::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - recip::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - recip::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Assign(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::Bool => { - assign::(inputs, outputs, &mut locals, write_pos, op, config) - } - }, - ElemwiseOp::Exp(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - exp::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - exp::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - exp::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Cos(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - cos::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - 
cos::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - cos::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Sin(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - sin::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - sin::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - sin::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Tanh(op) => match op.out.precision() { - ElemwisePrecision::F32 => { - tanh::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - tanh::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - tanh::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Equal(op) => match op.lhs.precision() { - ElemwisePrecision::F32 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::Bool => { - equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - }, - ElemwiseOp::Greater(op) => match op.lhs.precision() { - ElemwisePrecision::F32 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - greater::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::GreaterEqual(op) => match op.lhs.precision() { - ElemwisePrecision::F32 => { - 
greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - greater_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::Lower(op) => match op.lhs.precision() { - ElemwisePrecision::F32 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - lower::(inputs, outputs, &mut locals, write_pos, op, config) - } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::LowerEqual(op) => match op.lhs.precision() { - ElemwisePrecision::F32 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::F16 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::BF16 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I64 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I32 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I16 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::I8 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U64 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U32 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U16 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) - } - ElemwisePrecision::U8 => { - lower_equal::(inputs, outputs, &mut locals, write_pos, op, config) 
- } - _ => comptime![panic!("Unsupported precision {op:?}")], - }, - ElemwiseOp::ConditionalAssign { - cond, - lhs, - rhs, - out, - } => match out.precision() { - ElemwisePrecision::F32 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::F16 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::BF16 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::I64 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::I32 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::I16 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::I8 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::U64 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::U32 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::U16 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - ElemwisePrecision::U8 => conditional_assign::( - inputs, - outputs, - &mut locals, - write_pos, - cond, - lhs, - rhs, - out, - config, - ), - _ => comptime![panic!("Unsupported precision")], - }, - ElemwiseOp::Gather { - input, - indices, - output, - dim, - } => match output.precision() { - ElemwisePrecision::F32 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::F16 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::BF16 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::I64 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::I32 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::I16 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::I8 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::U64 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::U32 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::U16 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::U8 => gather::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - _ => comptime![panic!("Unsupported precision")], - }, - ElemwiseOp::Select { - input, - indices, - output, - dim, - } => match 
output.precision() { - ElemwisePrecision::F32 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::F16 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::BF16 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::I64 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::I32 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::I16 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::I8 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::U64 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::U32 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::U16 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - ElemwisePrecision::U8 => select_indices::( - inputs, - outputs, - &mut locals, - write_pos, - dim, - input, - indices, - output, - config, - ), - _ => comptime![panic!("Unsupported precision")], - }, - } - } -} - -macro_rules! binary_op { - ($ident:ident, $op:tt) => { - #[cube] - fn $ident( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - locals: &mut LocalArgs, - write_pos: u32, - #[comptime] op: BinaryElemwiseArgs, - #[comptime] config: &ElemwiseConfig, - ) { - let lhs = read::(inputs, outputs, &locals, write_pos, op.lhs, config); - let rhs = read::(inputs, outputs, &locals, write_pos, op.rhs, config); - let result = lhs $op rhs; - - write::(inputs, outputs, locals, write_pos, result, op.out, config); - } - }; -} - -macro_rules! binary_func { - ($ident:ident, $func:expr, $c:tt) => { - #[cube] - fn $ident( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - locals: &mut LocalArgs, - write_pos: u32, - #[comptime] op: BinaryElemwiseArgs, - #[comptime] config: &ElemwiseConfig, - ) { - let lhs = read::(inputs, outputs, &locals, write_pos, op.lhs, config); - let rhs = read::(inputs, outputs, &locals, write_pos, op.rhs, config); - let result = $func(lhs, rhs); - - write::(inputs, outputs, locals, write_pos, result, op.out, config); - } - }; -} - -macro_rules! comparison_op { - ($ident:ident, $op:tt) => { - #[cube] - fn $ident( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - locals: &mut LocalArgs, - write_pos: u32, - #[comptime] op: BinaryElemwiseArgs, - #[comptime] config: &ElemwiseConfig, - ) { - let lhs = read::(inputs, outputs, &locals, write_pos, op.lhs, config); - let rhs = read::(inputs, outputs, &locals, write_pos, op.rhs, config); - let result = Line::new(lhs $op rhs); - - write::(inputs, outputs, locals, write_pos, result, op.out, config); - } - }; -} - -macro_rules! 
unary_func { - ($ident:ident, $func:expr, $c:tt) => { - #[cube] - fn $ident( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - locals: &mut LocalArgs, - write_pos: u32, - #[comptime] op: UnaryElemwiseArgs, - #[comptime] config: &ElemwiseConfig, - ) { - let input = read::(inputs, outputs, &locals, write_pos, op.input, config); - let result = $func(input); - - write::(inputs, outputs, locals, write_pos, result, op.out, config); - } - }; -} - -#[cube] -fn assign( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - locals: &mut LocalArgs, - write_pos: u32, - #[comptime] op: UnaryElemwiseArgs, - #[comptime] config: &ElemwiseConfig, -) { - let input = read::(inputs, outputs, locals, write_pos, op.input, config); - - write::(inputs, outputs, locals, write_pos, input, op.out, config); -} - -#[cube] -fn gather( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - locals: &mut LocalArgs, - write_pos: u32, - #[comptime] dim: u32, - #[comptime] input: Arg, - #[comptime] indices: Arg, - #[comptime] output: Arg, - #[comptime] config: &ElemwiseConfig, -) { - let mut index = read::(inputs, outputs, locals, write_pos, indices, config); - let (pos, precision) = comptime! { - match input { - Arg::Input(pos, precision, _) => (pos, precision), - _ => panic!("Input tensor isn't an input"), - } - }; - let line_size = match config.ref_layout { - Arg::Input(pos, precision, _) => global_line_size(inputs, pos, precision), - Arg::Output(pos, precision, _) => global_line_size(outputs, pos, precision), - _ => unreachable!(), - }; - let stride = global_stride(inputs, dim, pos, precision); - - index *= Line::new(stride); - - if comptime![dim > 0] { - let index_before = global_offset( - inputs, - outputs, - write_pos, - comment!(input.clone()), - comptime![Some((0u32, dim))], - config, - ); - index += Line::new(index_before); - } - - if comptime![dim + 1 < config.rank] { - let index_after = global_offset( - inputs, - outputs, - write_pos, - input, - comptime![Some((dim + 1, config.rank))], - config, - ); - index += Line::new(index_after); - } - - let mut result = Line::empty(line_size); - - #[unroll] - for i in 0..line_size { - let index = index[i]; - - let input = read_input::( - inputs, - outputs, - pos, - index, - LayoutInfo::IsRef, - precision, - config, - None, - ); - result[i] = input[0]; - } - - write::(inputs, outputs, locals, write_pos, result, output, config); -} - -#[cube] -fn select_indices( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - locals: &mut LocalArgs, - write_pos: u32, - #[comptime] dim: u32, - #[comptime] input: Arg, - #[comptime] indices: Arg, - #[comptime] output: Arg, - #[comptime] config: &ElemwiseConfig, -) { - let (line_size_ref, stride_dim_ref, shape_dim_ref) = match config.ref_layout { - Arg::Input(pos, precision, _) => ( - global_line_size(inputs, pos, precision), - global_stride(inputs, dim, pos, precision), - global_shape(inputs, dim, pos, precision), - ), - Arg::Output(pos, precision, _) => ( - global_line_size(outputs, pos, precision), - global_stride(outputs, dim, pos, precision), - global_shape(outputs, dim, pos, precision), - ), - _ => unreachable!(), - }; - - let (pos_input, precision_input) = comptime! { - match input { - Arg::Input(pos, precision, _) => (pos, precision), - _ => panic!("Input tensor isn't an input"), - } - }; - let (pos_indices, precision_indices) = match indices { - Arg::Input(pos, precision, ..) 
=> (pos, precision), - _ => panic!("Indices tensor isn't an input"), - }; - - let stride_input_dim = global_stride(inputs, dim, pos_input, precision_input); - - let mut index = 0u32; - let mut result = Line::empty(line_size_ref); - - if comptime![dim != config.rank - 1] { - // In this scenario the select is actually broadcasted along the axis we're working on. - // - // Therefore the same indices are used to fetch multiple entries in the input tensor. - - let write_pos_input = write_pos * line_size_ref; - let stride_input_line = global_stride( - inputs, - comptime![config.rank - 1], - pos_input, - precision_input, - ); - - if comptime![dim > 0] { - let index_before = global_offset( - inputs, - outputs, - write_pos_input, - comment!(input.clone()), - comptime![Some((0u32, dim))], - config, - ); - index += index_before; - } - - if comptime![dim + 1 < config.rank] { - let index_after = global_offset( - inputs, - outputs, - write_pos_input, - comment!(input.clone()), - comptime![Some((dim + 1, config.rank))], - config, - ); - index += index_after; - } - - let coordinate_dim = write_pos_input / stride_dim_ref % shape_dim_ref; - let offset_dim = read_input::( - inputs, - outputs, - pos_indices, - coordinate_dim, - LayoutInfo::IsRef, - precision_indices, - config, - None, - ); - - index *= line_size_ref; - index += offset_dim[0] * stride_input_dim; - - #[unroll] - for i in 0..line_size_ref { - let input = read_input::( - inputs, - outputs, - pos_input, - index + i * stride_input_line, - LayoutInfo::IsRef, - precision_input, - config, - None, - ); - result[i] = input[0]; - } - } else { - // In this scenario the select is actually performed on the last dimension we're working on. - // - // Therefore we need to fetch multiple indices that correspond to different entries in the - // input tensor. 
- - if comptime![dim > 0] { - let index_before = global_offset( - inputs, - outputs, - write_pos, - comment!(input.clone()), - comptime![Some((0u32, dim))], - config, - ); - index += index_before; - } - - if comptime![dim + 1 < config.rank] { - let index_after = global_offset( - inputs, - outputs, - write_pos, - input, - comptime![Some((dim + 1, config.rank))], - config, - ); - index += index_after; - } - - let write_pos_indices = write_pos * line_size_ref; - - #[unroll] - for i in 0..line_size_ref { - let coordinate_dim = (write_pos_indices + i) / stride_dim_ref % shape_dim_ref; - let offset_dim = read_input::( - inputs, - outputs, - pos_indices, - coordinate_dim, - LayoutInfo::IsRef, - precision_indices, - config, - None, - ); - - let input = read_input::( - inputs, - outputs, - pos_input, - index + (offset_dim[0] * stride_input_dim), - LayoutInfo::IsRef, - precision_input, - config, - None, - ); - result[i] = input[0]; - } - } - - write::(inputs, outputs, locals, write_pos, result, output, config); -} - -#[cube] -fn conditional_assign( - inputs: &GlobalArgs, - outputs: &mut GlobalArgs, - locals: &mut LocalArgs, - write_pos: u32, - #[comptime] cond: Arg, - #[comptime] lhs: Arg, - #[comptime] rhs: Arg, - #[comptime] out: Arg, - #[comptime] config: &ElemwiseConfig, -) { - let cond = read::(inputs, outputs, locals, write_pos, cond, config); - let lhs = read::(inputs, outputs, locals, write_pos, lhs, config); - let rhs = read::(inputs, outputs, locals, write_pos, rhs, config); - let result = select_many(cond, lhs, rhs); - - write::(inputs, outputs, locals, write_pos, result, out, config); -} - -binary_op!(add, +); -binary_op!(mul, *); -binary_op!(div, /); -binary_op!(sub, -); - -comparison_op!(equal, ==); -comparison_op!(greater, >); -comparison_op!(greater_equal, >=); -comparison_op!(lower, <); -comparison_op!(lower_equal, <=); - -binary_func!(powf, Line::::powf, Float); - -unary_func!(exp, Line::::exp, Float); -unary_func!(log, Line::::log, Float); -unary_func!(log1p, Line::::log1p, Float); -unary_func!(cos, Line::::cos, Float); -unary_func!(sin, Line::::sin, Float); -unary_func!(tanh, Line::::tanh, Float); -unary_func!(erf, Line::::erf, Float); -unary_func!(recip, Line::::recip, Float); -unary_func!(abs, Line::::abs, Numeric); diff --git a/crates/burn-cubecl/src/fusion/on_write/trace/executor.rs b/crates/burn-cubecl/src/fusion/on_write/trace/executor.rs deleted file mode 100644 index 1b07fdcda6..0000000000 --- a/crates/burn-cubecl/src/fusion/on_write/trace/executor.rs +++ /dev/null @@ -1,230 +0,0 @@ -use std::{collections::BTreeMap, marker::PhantomData}; - -use burn_fusion::stream::Context; -use burn_tensor::DType; -use cubecl::{ - client::ComputeClient, - prelude::{ScalarArg, Sequence, TensorArg}, -}; - -use super::{HandleInput, HandleOutput, LaunchPlan, TensorView, TraceRunner}; -use crate::{ - fusion::{ - on_write::ir::{ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}, - CubeFusionHandle, - }, - BoolElement, CubeRuntime, -}; - -/// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context). 
-pub struct LaunchPlanExecutor<'a, R: CubeRuntime> { - scalars: &'a BTreeMap, - views: &'a Vec, - ops: &'a Vec, - _r: PhantomData, -} - -#[derive(new)] -pub struct ExecutionError> { - pub runner_error: Runner::Error, - pub handles_input: Vec>, - pub handles_output: Vec>, -} - -impl<'a, R: CubeRuntime> LaunchPlanExecutor<'a, R> { - pub fn new( - scalars: &'a BTreeMap, - views: &'a Vec, - ops: &'a Vec, - ) -> Self { - Self { - scalars, - views, - ops, - _r: PhantomData, - } - } - - pub fn execute, BT: BoolElement>( - self, - client: &ComputeClient, - runner: &Runner, - context: &mut Context<'_, CubeFusionHandle>, - plan: LaunchPlan<'a, R>, - ) -> Result<(), ExecutionError> { - let reference = match plan.reference { - Some(reference) => reference, - None => { - if plan.writes.is_empty() { - // Nothing to write, can skip execution. - return Ok(()); - } else { - panic!("An output should exist for the fused kernel") - } - } - }; - - let inputs = self.register_inputs(context, &plan.handle_inputs); - let outputs = self.register_outputs::(&plan.handle_outputs); - - let mut ops = Sequence::::new(); - - for read_ops in plan.reads.into_values() { - for op in read_ops { - ops.push(op); - } - } - - for op in self.ops.iter() { - ops.push(op.clone()); - } - - for op in plan.writes.into_values() { - ops.push(op); - } - - let config = ElemwiseConfig { - rank: plan.rank as u32, - ref_layout: reference.layout, - ops, - }; - - Runner::run(runner, client, inputs, outputs, &config) - .map_err(|err| ExecutionError::new(err, plan.handle_inputs, plan.handle_outputs)) - } - - fn register_inputs<'h>( - &self, - context: &mut Context<'_, CubeFusionHandle>, - handle_inputs: &'h [HandleInput], - ) -> GlobalArgsLaunch<'h, R> { - let mut inputs = GlobalArgsLaunch::default(); - - for hi in handle_inputs.iter() { - let arg = hi.handle.as_tensor_arg(&hi.global_shape, hi.vectorization); - match hi.precision { - ElemwisePrecision::F32 => inputs.t_f32.push(arg), - ElemwisePrecision::F16 => inputs.t_f16.push(arg), - ElemwisePrecision::BF16 => inputs.t_bf16.push(arg), - ElemwisePrecision::I64 => inputs.t_i64.push(arg), - ElemwisePrecision::I32 => inputs.t_i32.push(arg), - ElemwisePrecision::I16 => inputs.t_i16.push(arg), - ElemwisePrecision::I8 => inputs.t_i8.push(arg), - ElemwisePrecision::U64 => inputs.t_u64.push(arg), - ElemwisePrecision::U32 => inputs.t_u32.push(arg), - ElemwisePrecision::U16 => inputs.t_u16.push(arg), - ElemwisePrecision::U8 => inputs.t_u8.push(arg), - _ => panic!("Unsupported input precision {:?}", hi.precision), - }; - } - - for (precision, count) in self.scalars.iter() { - for i in 0..(*count as usize) { - match precision { - ElemwisePrecision::F32 => { - inputs.s_f32.push(ScalarArg::new(context.scalar_f32[i])) - } - ElemwisePrecision::F16 => { - inputs.s_f16.push(ScalarArg::new(context.scalar_f16[i])) - } - ElemwisePrecision::BF16 => { - inputs.s_bf16.push(ScalarArg::new(context.scalar_bf16[i])) - } - ElemwisePrecision::I64 => { - inputs.s_i64.push(ScalarArg::new(context.scalar_i64[i])) - } - ElemwisePrecision::I32 => { - inputs.s_i32.push(ScalarArg::new(context.scalar_i32[i])) - } - ElemwisePrecision::I16 => { - inputs.s_i16.push(ScalarArg::new(context.scalar_i16[i])) - } - ElemwisePrecision::I8 => inputs.s_i8.push(ScalarArg::new(context.scalar_i8[i])), - ElemwisePrecision::U64 => { - inputs.s_u64.push(ScalarArg::new(context.scalar_u64[i])) - } - ElemwisePrecision::U32 => { - inputs.s_u32.push(ScalarArg::new(context.scalar_u32[i])) - } - ElemwisePrecision::U16 => { - 
inputs.s_u16.push(ScalarArg::new(context.scalar_u16[i])) - } - ElemwisePrecision::U8 => inputs.s_u8.push(ScalarArg::new(context.scalar_u8[i])), - ElemwisePrecision::Bool => todo!(), - } - } - } - - // Reshape values are pushed in reverse in the same scalar buffer for all `u32` - for relative in self.views.iter().rev() { - if let TensorView::Reshape { reshaped, .. } = relative { - let global = context.tensors.get(reshaped).unwrap(); - - for shape in global.shape.iter().rev() { - inputs.s_u32.push(ScalarArg::new(*shape as u32)) - } - } - } - - inputs - } - - fn register_outputs<'s, BT: BoolElement>( - &self, - handle_outputs: &'s [HandleOutput], - ) -> GlobalArgsLaunch<'s, R> { - let mut outputs = GlobalArgsLaunch::default(); - - for item in handle_outputs.iter() { - match item { - HandleOutput::Alias { - input_pos, - precision, - } => match precision { - ElemwisePrecision::F32 => outputs.t_f32.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::F16 => outputs.t_f16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::BF16 => outputs.t_bf16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I64 => outputs.t_i64.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I32 => outputs.t_i32.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I16 => outputs.t_i16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I8 => outputs.t_i8.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U64 => outputs.t_u64.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U32 => outputs.t_u32.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U16 => outputs.t_u16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U8 => outputs.t_u8.push(TensorArg::alias(*input_pos)), - _ => todo!(), - }, - HandleOutput::Owned { - precision, - handle, - global_shape, - vectorization, - .. 
- } => { - let arg = handle.as_tensor_arg(global_shape, *vectorization); - - match precision { - ElemwisePrecision::F32 => outputs.t_f32.push(arg), - ElemwisePrecision::F16 => outputs.t_f16.push(arg), - ElemwisePrecision::BF16 => outputs.t_bf16.push(arg), - ElemwisePrecision::I64 => outputs.t_i64.push(arg), - ElemwisePrecision::I32 => outputs.t_i32.push(arg), - ElemwisePrecision::I16 => outputs.t_i16.push(arg), - ElemwisePrecision::I8 => outputs.t_i8.push(arg), - ElemwisePrecision::U64 => outputs.t_u64.push(arg), - ElemwisePrecision::U32 => outputs.t_u32.push(arg), - ElemwisePrecision::U16 => outputs.t_u16.push(arg), - ElemwisePrecision::U8 => outputs.t_u8.push(arg), - ElemwisePrecision::Bool => match BT::dtype() { - DType::U32 => outputs.t_u32.push(arg), - DType::U8 => outputs.t_u8.push(arg), - _ => todo!(), - }, - }; - } - } - } - - outputs - } -} diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/homogeneous/base.rs index 8e36462bd3..8e1e407add 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/homogeneous/base.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/homogeneous/base.rs @@ -12,7 +12,7 @@ use cubecl::{ stage::{ self, multi_buffer::{LhsReader, LhsReaderFamily, RhsReader, RhsReaderFamily}, - StageMatmulFamily, TilingOrderConfig, + StageMatmulFamily, TilingLayout, }, Ident, InvalidConfigError, MatrixLayout, StageTiling, }, @@ -357,8 +357,8 @@ pub mod config { self.matmul.plane_dim() } - fn tiling_order(&self, ident: Ident) -> TilingOrderConfig { - self.matmul.tiling_order(ident) + fn tiling_layout(&self, ident: Ident) -> TilingLayout { + self.matmul.tiling_layout(ident) } fn check_row_bounds(&self, ident: Ident) -> bool { diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/loader/im2col.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/loader/im2col.rs index 471a122888..0c7a7b789e 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/loader/im2col.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/loader/im2col.rs @@ -1,18 +1,15 @@ +use cubecl::prelude::pipeline::Pipeline; use cubecl::{ linalg::{ matmul::components::{ global::InputLoader, - stage::{ - multi_buffer::LhsReader, ColMajorTiling, RowMajorTiling, Stage, TilingOrder as _, - TilingOrderConfig, - }, + stage::{multi_buffer::LhsReader, Stage, TilingLayout}, Ident, }, tensor::VirtualTensor, }, prelude::*, }; -use cubecl::prelude::pipeline::Pipeline; use std::marker::PhantomData; use crate::kernel::conv::{precision::ConvPrecision, reader::im2col::Im2colReader, ConvGemmConfig}; @@ -124,18 +121,12 @@ impl SimpleIm2col { let nth_tile = unit_position / tile_num_elements; let pos_within_tile = unit_position % tile_num_elements; - let (tile_x, tile_y) = match config.tiling_order(ident) { - TilingOrderConfig::RowMajor => RowMajorTiling::to_x_y( - nth_tile, - stage_tiling.tile_count_row(), - stage_tiling.tile_count_col(), - ), - TilingOrderConfig::ColMajor => ColMajorTiling::to_x_y( - nth_tile, - stage_tiling.tile_count_row(), - stage_tiling.tile_count_col(), - ), - }; + let (tile_x, tile_y) = TilingLayout::to_x_y( + config.tiling_layout(ident), + nth_tile, + stage_tiling.tile_count_row(), + stage_tiling.tile_count_col(), + ); let line_read = read_view.load_simple::(tile_x, tile_y, pos_within_tile, ident, config); diff --git a/crates/burn-cubecl/src/lib.rs b/crates/burn-cubecl/src/lib.rs index 17ae9983b3..776e724c21 100644 --- a/crates/burn-cubecl/src/lib.rs +++ b/crates/burn-cubecl/src/lib.rs @@ -18,7 +18,6 @@ pub mod 
tensor; /// Elements for JIT backend pub mod element; -use burn_tensor::backend::{DeviceId, DeviceOps}; use cubecl::{compute::CubeTask, Feature, Runtime}; pub use element::{BoolElement, CubeElement, FloatElement, IntElement}; @@ -54,28 +53,4 @@ pub trait CubeRuntime: Runtime; } -/// ID used to identify a Just-in-Time environment. -#[derive(Hash, PartialEq, Eq, Debug, Clone)] -pub struct CubeTuneId { - device: DeviceId, - name: &'static str, -} - -impl CubeTuneId { - /// Create a new ID. - pub fn new(device: &R::Device) -> Self { - Self { - device: DeviceOps::id(device), - name: R::name(), - } - } -} - -impl core::fmt::Display for CubeTuneId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!( - "device-{}-{}-{}", - self.device.type_id, self.device.index_id, self.name - )) - } -} +pub use cubecl::CubeTuneId; diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index b0abf7d178..a201d8f728 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -13,9 +13,49 @@ rust-version = "1.82" version.workspace = true [features] -default = ["burn-core/default", "burn-train?/default", "std"] -doc = ["default", "train", "burn-core/doc", "burn-train/doc"] -std = ["burn-core/std"] +default = [ + "burn-core/default", + "burn-train?/default", + "std", + # Backends + "burn-candle?/default", + "burn-ndarray?/default", + "burn-tch?/default", + "burn-wgpu?/default", + "burn-router?/default", + "burn-cuda?/default", + "burn-autodiff?/default", + "burn-hip?/default", +] +doc = [ + "default", + "train", + "burn-core/doc", + "burn-train/doc", + # Backends + "burn-candle/doc", + "burn-ndarray/doc", + "burn-tch/doc", + "burn-wgpu/doc", + "burn-router/doc", + "burn-cuda/doc", + "burn-autodiff?/std", + "burn-hip/doc", + +] +std = [ + "burn-core/std", + # Backends + "burn-candle?/std", + "burn-ndarray?/std", + "burn-wgpu?/std", + "burn-router?/std", + "burn-cuda?/std", + "burn-autodiff?/std", + "burn-hip?/std", +] + +network = ["burn-core/network"] # Training with full features train = ["burn-train", "autodiff", "dataset"] @@ -32,43 +72,40 @@ dataset = ["burn-core/dataset"] sqlite = ["burn-core/sqlite"] sqlite-bundled = ["burn-core/sqlite-bundled"] +# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. 
+record-item-custom-serde = ["burn-core/record-item-custom-serde"] +# Serialization formats +experimental-named-tensor = ["burn-core/experimental-named-tensor"] + + audio = ["burn-core/audio"] vision = ["burn-core/vision"] -# Backends -autodiff = ["burn-core/autodiff"] -fusion = ["burn-core/fusion"] +# Backend +autodiff = ["burn-autodiff"] +fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"] ## Backend features -accelerate = ["burn-core/accelerate"] -autotune = ["burn-core/autotune"] -blas-netlib = ["burn-core/blas-netlib"] -candle-cuda = ["burn-core/candle-cuda"] -metal = ["burn-core/metal"] -openblas = ["burn-core/openblas"] -openblas-system = ["burn-core/openblas-system"] -template = ["burn-core/template"] - -candle = ["burn-core/candle"] -cuda = ["burn-core/cuda"] -hip = ["burn-core/hip"] -ndarray = ["burn-core/ndarray"] -remote = ["burn-core/remote"] -router = ["burn-core/router"] -server = ["burn-core/server"] -tch = ["burn-core/tch"] -wgpu = ["burn-core/wgpu"] -vulkan = ["burn-core/vulkan"] -webgpu = ["burn-core/webgpu"] - -# Network utils -network = ["burn-core/network"] - -# Experimental -experimental-named-tensor = ["burn-core/experimental-named-tensor"] - -# Records -record-item-custom-serde = ["burn-core/record-item-custom-serde"] +accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"] +autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"] +blas-netlib = ["burn-ndarray?/blas-netlib"] +metal = ["burn-candle?/metal"] +openblas = ["burn-ndarray?/blas-openblas"] +openblas-system = ["burn-ndarray?/blas-openblas-system"] +remote = ["burn-remote/client"] +router = ["burn-router"] +server = ["burn-remote/server"] +template = ["burn-wgpu?/template"] + +candle = ["burn-candle"] +candle-cuda = ["candle", "burn-candle/cuda"] +cuda = ["burn-cuda"] +hip = ["burn-hip"] +ndarray = ["burn-ndarray"] +tch = ["burn-tch"] +wgpu = ["burn-wgpu"] +vulkan = ["wgpu", "burn-wgpu/vulkan"] +webgpu = ["wgpu", "burn-wgpu/webgpu"] [dependencies] @@ -76,3 +113,14 @@ record-item-custom-serde = ["burn-core/record-item-custom-serde"] burn-core = { path = "../burn-core", version = "0.17.0", default-features = false } burn-train = { path = "../burn-train", version = "0.17.0", optional = true, default-features = false } + +# Backends +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } +burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true } +burn-cuda = { path = "../burn-cuda", version = "0.17.0", optional = true, default-features = false } +burn-hip = { path = "../burn-hip", version = "0.17.0", optional = true, default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true, default-features = false } +burn-remote = { path = "../burn-remote", version = "0.17.0", default-features = false, optional = true } +burn-router = { path = "../burn-router", version = "0.17.0", default-features = false, optional = true } +burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", optional = true, default-features = false } diff --git a/crates/burn-core/src/backend.rs b/crates/burn/src/backend.rs similarity index 100% rename from crates/burn-core/src/backend.rs rename to crates/burn/src/backend.rs diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs index 203d1a802d..f12e6e00cc 100644 --- a/crates/burn/src/lib.rs +++ b/crates/burn/src/lib.rs @@ -107,3 +107,9 @@ pub use burn_core::*; pub mod train { pub use 
burn_train::*; } + +/// Backend module. +pub mod backend; + +#[cfg(feature = "server")] +pub use burn_remote::server;
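
For reference, a minimal downstream sketch of how the reorganized `burn` feature flags above might be exercised after this change. This is not part of the patch: the crate versions and feature list in the comments are illustrative assumptions, and the `Wgpu`/`Autodiff` re-exports are assumed to be those already provided by the relocated `backend.rs` (renamed above with 100% similarity, so its contents are unchanged).

// Hypothetical downstream usage; assumes a Cargo.toml entry roughly like:
//   burn = { version = "0.17.0", features = ["wgpu", "fusion", "autotune", "autodiff"] }
// With the feature layout in this patch, `fusion` forwards to `burn-wgpu?/fusion`
// (and `burn-cuda?/fusion`) instead of going through `burn-core/fusion`.
use burn::backend::{Autodiff, Wgpu};
use burn::tensor::Tensor;

fn main() {
    // `burn::backend` is the module this patch moves from `burn-core` into `burn`.
    type B = Autodiff<Wgpu>;
    let device = Default::default();

    let x = Tensor::<B, 2>::ones([2, 2], &device);
    // Element-wise chains like this are the kind of pattern the relocated
    // `burn-cubecl-fusion` crate can fuse into a single kernel when `fusion` is enabled.
    let y = (x.clone() + x) * 2.0;
    println!("{y}");
}

The design choice reflected in the manifest diff is that backend crates become optional direct dependencies of `burn`, so backend-specific features (`fusion`, `autotune`, `vulkan`, ...) are forwarded to the individual backend crates rather than funneled through `burn-core`.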