Skip to content

Commit bca6231

Browse files
authored
Introduce cuda.cooperative overloads not requiring temporary storage (NVIDIA#2528)
* Modernize pkg resource query * Add cooperative overloads without shared memory * Start fixing temp storage * Incorporate template params into mangling * Condence dict access * Fix temporary storage indexing for sub hw waprs * Test multiple warps * Disable alloc API for sub hw warps
1 parent 3892c32 commit bca6231

14 files changed

+303
-155
lines changed

python/cuda_cooperative/cuda/cooperative/experimental/_nvrtc.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import importlib.resources as pkg_resources
1111
import functools
1212

13+
1314
def CHECK_NVRTC(err, prog):
1415
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
1516
err, logsize = nvrtc.nvrtcGetProgramLogSize(prog)
@@ -39,8 +40,8 @@ def get_cuda_path():
3940
# rdc is true or false
4041
# code is lto or ptx
4142
# @cache
42-
@functools.lru_cache(maxsize=32) # Always enabled
43-
@disk_cache # Optional, see caching.py
43+
@functools.lru_cache(maxsize=32) # Always enabled
44+
@disk_cache # Optional, see caching.py
4445
def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):
4546
check_in('rdc', rdc, [True, False])
4647
check_in('code', code, ['lto', 'ptx'])
@@ -54,11 +55,11 @@ def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):
5455
libcudacxx_path = os.path.join(include_path, 'libcudacxx')
5556
cuda_include_path = os.path.join(get_cuda_path(), 'include')
5657

57-
opts = [b"--std=c++17", \
58-
bytes(f"--include-path={cub_path}", encoding='ascii'), \
59-
bytes(f"--include-path={thrust_path}", encoding='ascii'), \
60-
bytes(f"--include-path={libcudacxx_path}", encoding='ascii'), \
61-
bytes(f"--include-path={cuda_include_path}", encoding='ascii'), \
58+
opts = [b"--std=c++17",
59+
bytes(f"--include-path={cub_path}", encoding='ascii'),
60+
bytes(f"--include-path={thrust_path}", encoding='ascii'),
61+
bytes(f"--include-path={libcudacxx_path}", encoding='ascii'),
62+
bytes(f"--include-path={cuda_include_path}", encoding='ascii'),
6263
bytes(f"--gpu-architecture=compute_{cc}", encoding='ascii')]
6364
if rdc:
6465
opts += [b"--relocatable-device-code=true"]
@@ -70,7 +71,8 @@ def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):
7071
opts += [b"-DCCCL_DISABLE_BF16_SUPPORT"]
7172

7273
# Create program
73-
err, prog = nvrtc.nvrtcCreateProgram(str.encode(cpp), b"code.cu", 0, [], [])
74+
err, prog = nvrtc.nvrtcCreateProgram(
75+
str.encode(cpp), b"code.cu", 0, [], [])
7476
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
7577
raise RuntimeError(f"nvrtcCreateProgram error: {err}")
7678

@@ -103,12 +105,13 @@ def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):
103105

104106
return ptx.decode('ascii')
105107

108+
106109
def compile(**kwargs):
107110

108111
err, major, minor = nvrtc.nvrtcVersion()
109112
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
110113
raise RuntimeError(f"nvrtcVersion error: {err}")
111114
nvrtc_version = version(major, minor)
112-
return nvrtc_version, compile_impl(**kwargs, \
113-
nvrtc_path=nvrtc.__file__, \
114-
nvrtc_version=nvrtc_version)
115+
return nvrtc_version, compile_impl(**kwargs,
116+
nvrtc_path=nvrtc.__file__,
117+
nvrtc_version=nvrtc_version)

0 commit comments

Comments
 (0)