|
8 | 8 | //
|
9 | 9 | //===----------------------------------------------------------------------===//
|
10 | 10 |
|
11 |
| -#include <cub/agent/single_pass_scan_operators.cuh> |
12 | 11 | #include <cub/detail/choose_offset.cuh>
|
13 | 12 | #include <cub/detail/launcher/cuda_driver.cuh>
|
14 | 13 | #include <cub/device/dispatch/dispatch_scan.cuh>
|
|
20 | 19 | #include <format>
|
21 | 20 | #include <iostream>
|
22 | 21 | #include <optional>
|
23 |
| -#include <regex> |
24 | 22 | #include <string>
|
25 | 23 | #include <type_traits>
|
26 | 24 |
|
|
30 | 28 | #include "util/context.h"
|
31 | 29 | #include "util/errors.h"
|
32 | 30 | #include "util/indirect_arg.h"
|
| 31 | +#include "util/scan_tile_state.h" |
33 | 32 | #include "util/types.h"
|
34 | 33 | #include <cccl/c/scan.h>
|
35 | 34 | #include <nvrtc.h>
|
@@ -172,74 +171,6 @@ std::string get_scan_kernel_name(cccl_iterator_t input_it, cccl_iterator_t outpu
|
172 | 171 | init_t); // 9
|
173 | 172 | }
|
174 | 173 |
|
175 |
| -// TODO: NVRTC doesn't currently support extracting basic type |
176 |
| -// information (e.g., type sizes and alignments) from compiled |
177 |
| -// LTO-IR. So we separately compile a small PTX file that defines the |
178 |
| -// necessary types and constants and grep it for the required |
179 |
| -// information. If/when NVRTC adds these features, we can remove this |
180 |
| -// extra compilation step and get the information directly from the |
181 |
| -// LTO-IR. |
182 |
| -static constexpr auto ptx_u64_assignment_regex = R"(\.visible\s+\.global\s+\.align\s+\d+\s+\.u64\s+{}\s*=\s*(\d+);)"; |
183 |
| - |
184 |
| -std::optional<size_t> find_size_t(char* ptx, std::string_view name) |
185 |
| -{ |
186 |
| - std::regex regex(std::format(ptx_u64_assignment_regex, name)); |
187 |
| - std::cmatch match; |
188 |
| - if (std::regex_search(ptx, match, regex)) |
189 |
| - { |
190 |
| - auto result = std::stoi(match[1].str()); |
191 |
| - return result; |
192 |
| - } |
193 |
| - return std::nullopt; |
194 |
| -} |
195 |
| - |
196 |
| -struct scan_tile_state |
197 |
| -{ |
198 |
| - // scan_tile_state implements the same (host) interface as cub::ScanTileStateT, except |
199 |
| - // that it accepts the acummulator type as a runtime parameter rather than being |
200 |
| - // templated on it. |
201 |
| - // |
202 |
| - // Both specializations ScanTileStateT<T, true> and ScanTileStateT<T, false> - where the |
203 |
| - // bool parameter indicates whether `T` is primitive - are combined into a single type. |
204 |
| - |
205 |
| - void* d_tile_status; // d_tile_descriptors |
206 |
| - void* d_tile_partial; |
207 |
| - void* d_tile_inclusive; |
208 |
| - |
209 |
| - size_t description_bytes_per_tile; |
210 |
| - size_t payload_bytes_per_tile; |
211 |
| - |
212 |
| - scan_tile_state(size_t description_bytes_per_tile, size_t payload_bytes_per_tile) |
213 |
| - : d_tile_status(nullptr) |
214 |
| - , d_tile_partial(nullptr) |
215 |
| - , d_tile_inclusive(nullptr) |
216 |
| - , description_bytes_per_tile(description_bytes_per_tile) |
217 |
| - , payload_bytes_per_tile(payload_bytes_per_tile) |
218 |
| - {} |
219 |
| - |
220 |
| - cudaError_t Init(int num_tiles, void* d_temp_storage, size_t temp_storage_bytes) |
221 |
| - { |
222 |
| - void* allocations[3] = {}; |
223 |
| - auto status = cub::detail::tile_state_init( |
224 |
| - description_bytes_per_tile, payload_bytes_per_tile, num_tiles, d_temp_storage, temp_storage_bytes, allocations); |
225 |
| - if (status != cudaSuccess) |
226 |
| - { |
227 |
| - return status; |
228 |
| - } |
229 |
| - d_tile_status = allocations[0]; |
230 |
| - d_tile_partial = allocations[1]; |
231 |
| - d_tile_inclusive = allocations[2]; |
232 |
| - return cudaSuccess; |
233 |
| - } |
234 |
| - |
235 |
| - cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes) const |
236 |
| - { |
237 |
| - temp_storage_bytes = |
238 |
| - cub::detail::tile_state_allocation_size(description_bytes_per_tile, payload_bytes_per_tile, num_tiles); |
239 |
| - return cudaSuccess; |
240 |
| - } |
241 |
| -}; |
242 |
| - |
243 | 174 | template <auto* GetPolicy>
|
244 | 175 | struct dynamic_scan_policy_t
|
245 | 176 | {
|
@@ -392,43 +323,8 @@ struct device_scan_policy {{
|
392 | 323 | check(cuLibraryGetKernel(&build_ptr->init_kernel, build_ptr->library, init_kernel_lowered_name.c_str()));
|
393 | 324 | check(cuLibraryGetKernel(&build_ptr->scan_kernel, build_ptr->library, scan_kernel_lowered_name.c_str()));
|
394 | 325 |
|
395 |
| - constexpr size_t num_ptx_args = 7; |
396 |
| - const char* ptx_args[num_ptx_args] = { |
397 |
| - arch.c_str(), cub_path, thrust_path, libcudacxx_path, ctk_path, "-rdc=true", "-dlto"}; |
398 |
| - constexpr size_t num_ptx_lto_args = 3; |
399 |
| - const char* ptx_lopts[num_ptx_lto_args] = {"-lto", arch.c_str(), "-ptx"}; |
400 |
| - |
401 |
| - constexpr std::string_view ptx_src_template = R"XXX( |
402 |
| -#include <cub/agent/single_pass_scan_operators.cuh> |
403 |
| -#include <cub/util_type.cuh> |
404 |
| -struct __align__({1}) storage_t {{ |
405 |
| - char data[{0}]; |
406 |
| -}}; |
407 |
| -__device__ size_t description_bytes_per_tile = cub::ScanTileState<{2}>::description_bytes_per_tile; |
408 |
| -__device__ size_t payload_bytes_per_tile = cub::ScanTileState<{2}>::payload_bytes_per_tile; |
409 |
| -)XXX"; |
410 |
| - |
411 |
| - const std::string ptx_src = std::format(ptx_src_template, accum_t.size, accum_t.alignment, accum_cpp); |
412 |
| - auto compile_result = |
413 |
| - make_nvrtc_command_list() |
414 |
| - .add_program(nvrtc_translation_unit{ptx_src.c_str(), "tile_state_info"}) |
415 |
| - .compile_program({ptx_args, num_ptx_args}) |
416 |
| - .cleanup_program() |
417 |
| - .finalize_program(num_ptx_lto_args, ptx_lopts); |
418 |
| - auto ptx_code = compile_result.data.get(); |
419 |
| - |
420 |
| - size_t description_bytes_per_tile; |
421 |
| - size_t payload_bytes_per_tile; |
422 |
| - auto maybe_description_bytes_per_tile = scan::find_size_t(ptx_code, "description_bytes_per_tile"); |
423 |
| - if (maybe_description_bytes_per_tile) |
424 |
| - { |
425 |
| - description_bytes_per_tile = maybe_description_bytes_per_tile.value(); |
426 |
| - } |
427 |
| - else |
428 |
| - { |
429 |
| - throw std::runtime_error("Failed to find description_bytes_per_tile in PTX"); |
430 |
| - } |
431 |
| - payload_bytes_per_tile = scan::find_size_t(ptx_code, "payload_bytes_per_tile").value_or(0); |
| 326 | + auto [description_bytes_per_tile, |
| 327 | + payload_bytes_per_tile] = get_tile_state_bytes_per_tile(accum_t, accum_cpp, args, num_args, arch); |
432 | 328 |
|
433 | 329 | build_ptr->cc = cc;
|
434 | 330 | build_ptr->cubin = (void*) result.data.release();
|
|
0 commit comments