Refactoring

This commit is contained in:
André R. Brodtkorb
2018-10-31 10:45:48 +01:00
parent e434b4e02a
commit 71777dad4e
9 changed files with 136 additions and 84 deletions

View File

@@ -170,8 +170,9 @@ class CudaContext(object):
"""
def get_prepared_kernel(self, kernel_filename, kernel_function_name, \
prepared_call_args, \
include_dirs=[], no_extern_c=True,
**kwargs):
include_dirs=[], \
defines={}, \
compile_args={'no_extern_c', True}, jit_compile_args={}):
"""
Helper function to print compilation output
"""
@@ -183,19 +184,20 @@ class CudaContext(object):
self.logger.debug("Error: %s", error_str)
kernel_filename = os.path.normpath(kernel_filename)
kernel_path = os.path.abspath(os.path.join(self.module_path, kernel_filename))
#self.logger.debug("Getting %s", kernel_filename)
# Create a hash of the kernel (and its includes)
kwargs_hasher = hashlib.md5()
kwargs_hasher.update(str(kwargs).encode('utf-8'));
kwargs_hash = kwargs_hasher.hexdigest()
kwargs_hasher = None
options_hasher = hashlib.md5()
options_hasher.update(str(defines).encode('utf-8') + str(compile_args).encode('utf-8'));
options_hash = options_hasher.hexdigest()
options_hasher = None
root, ext = os.path.splitext(kernel_filename)
kernel_hash = root \
+ "_" + CudaContext.hash_kernel( \
os.path.join(self.module_path, kernel_filename), \
kernel_path, \
include_dirs=[self.module_path] + include_dirs) \
+ "_" + kwargs_hash \
+ "_" + options_hash \
+ ext
cached_kernel_filename = os.path.join(self.cache_path, kernel_hash)
@@ -210,7 +212,7 @@ class CudaContext(object):
with io.open(cached_kernel_filename, "rb") as file:
file_str = file.read()
module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler)
module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args)
kernel = module.get_function(kernel_function_name)
kernel.prepare(prepared_call_args)
@@ -223,7 +225,7 @@ class CudaContext(object):
#Create kernel string
kernel_string = ""
for key, value in kwargs.items():
for key, value in defines.items():
kernel_string += "#define {:s} {:s}\n".format(str(key), str(value))
kernel_string += '#include "{:s}"'.format(os.path.join(self.module_path, kernel_filename))
if (self.use_cache):
@@ -235,8 +237,11 @@ class CudaContext(object):
with Common.Timer("compiler") as timer:
cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, no_extern_c=no_extern_c, cache_dir=False)
module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler)
import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The CUDA compiler succeeded, but said the following:\nkernel.cu", category=UserWarning)
cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args)
if (self.use_cache):
with io.open(cached_kernel_filename, "wb") as file:
file.write(cubin)