10
10
import importlib .resources as pkg_resources
11
11
import functools
12
12
13
+
13
14
def CHECK_NVRTC (err , prog ):
14
15
if err != nvrtc .nvrtcResult .NVRTC_SUCCESS :
15
16
err , logsize = nvrtc .nvrtcGetProgramLogSize (prog )
@@ -39,8 +40,8 @@ def get_cuda_path():
39
40
# rdc is true or false
40
41
# code is lto or ptx
41
42
# @cache
42
- @functools .lru_cache (maxsize = 32 ) # Always enabled
43
- @disk_cache # Optional, see caching.py
43
+ @functools .lru_cache (maxsize = 32 ) # Always enabled
44
+ @disk_cache # Optional, see caching.py
44
45
def compile_impl (cpp , cc , rdc , code , nvrtc_path , nvrtc_version ):
45
46
check_in ('rdc' , rdc , [True , False ])
46
47
check_in ('code' , code , ['lto' , 'ptx' ])
@@ -54,11 +55,11 @@ def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):
54
55
libcudacxx_path = os .path .join (include_path , 'libcudacxx' )
55
56
cuda_include_path = os .path .join (get_cuda_path (), 'include' )
56
57
57
- opts = [b"--std=c++17" , \
58
- bytes (f"--include-path={ cub_path } " , encoding = 'ascii' ), \
59
- bytes (f"--include-path={ thrust_path } " , encoding = 'ascii' ), \
60
- bytes (f"--include-path={ libcudacxx_path } " , encoding = 'ascii' ), \
61
- bytes (f"--include-path={ cuda_include_path } " , encoding = 'ascii' ), \
58
+ opts = [b"--std=c++17" ,
59
+ bytes (f"--include-path={ cub_path } " , encoding = 'ascii' ),
60
+ bytes (f"--include-path={ thrust_path } " , encoding = 'ascii' ),
61
+ bytes (f"--include-path={ libcudacxx_path } " , encoding = 'ascii' ),
62
+ bytes (f"--include-path={ cuda_include_path } " , encoding = 'ascii' ),
62
63
bytes (f"--gpu-architecture=compute_{ cc } " , encoding = 'ascii' )]
63
64
if rdc :
64
65
opts += [b"--relocatable-device-code=true" ]
@@ -70,7 +71,8 @@ def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):
70
71
opts += [b"-DCCCL_DISABLE_BF16_SUPPORT" ]
71
72
72
73
# Create program
73
- err , prog = nvrtc .nvrtcCreateProgram (str .encode (cpp ), b"code.cu" , 0 , [], [])
74
+ err , prog = nvrtc .nvrtcCreateProgram (
75
+ str .encode (cpp ), b"code.cu" , 0 , [], [])
74
76
if err != nvrtc .nvrtcResult .NVRTC_SUCCESS :
75
77
raise RuntimeError (f"nvrtcCreateProgram error: { err } " )
76
78
@@ -103,12 +105,13 @@ def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):
103
105
104
106
return ptx .decode ('ascii' )
105
107
108
+
106
109
def compile(**kwargs):
    """Look up the NVRTC version and compile through ``compile_impl``.

    Every keyword argument is forwarded unchanged to ``compile_impl``; this
    wrapper additionally supplies ``nvrtc_path`` (the location of the
    imported ``nvrtc`` module) and ``nvrtc_version``.

    Returns:
        A ``(nvrtc_version, result)`` tuple, where ``result`` is whatever
        ``compile_impl`` returns.

    Raises:
        RuntimeError: if ``nvrtc.nvrtcVersion()`` does not report
            ``NVRTC_SUCCESS``.
    """
    status, ver_major, ver_minor = nvrtc.nvrtcVersion()
    if status != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        raise RuntimeError(f"nvrtcVersion error: {status}")

    nvrtc_version = version(ver_major, ver_minor)

    # NOTE(review): the path and version are presumably passed so that the
    # caches wrapping compile_impl are keyed on the NVRTC build -- confirm.
    # Caller kwargs stay first so the keyword order seen by compile_impl
    # (and therefore any order-sensitive cache key) is unchanged.
    compiled = compile_impl(
        **kwargs,
        nvrtc_path=nvrtc.__file__,
        nvrtc_version=nvrtc_version,
    )
    return nvrtc_version, compiled
0 commit comments