mirror of
https://github.com/smyalygames/FiniteVolumeGPU.git
synced 2026-01-14 15:48:43 +01:00
Refactoring
This commit is contained in:
@@ -170,8 +170,9 @@ class CudaContext(object):
|
||||
"""
|
||||
def get_prepared_kernel(self, kernel_filename, kernel_function_name, \
|
||||
prepared_call_args, \
|
||||
include_dirs=[], no_extern_c=True,
|
||||
**kwargs):
|
||||
include_dirs=[], \
|
||||
defines={}, \
|
||||
compile_args={'no_extern_c', True}, jit_compile_args={}):
|
||||
"""
|
||||
Helper function to print compilation output
|
||||
"""
|
||||
@@ -183,19 +184,20 @@ class CudaContext(object):
|
||||
self.logger.debug("Error: %s", error_str)
|
||||
|
||||
kernel_filename = os.path.normpath(kernel_filename)
|
||||
kernel_path = os.path.abspath(os.path.join(self.module_path, kernel_filename))
|
||||
#self.logger.debug("Getting %s", kernel_filename)
|
||||
|
||||
# Create a hash of the kernel (and its includes)
|
||||
kwargs_hasher = hashlib.md5()
|
||||
kwargs_hasher.update(str(kwargs).encode('utf-8'));
|
||||
kwargs_hash = kwargs_hasher.hexdigest()
|
||||
kwargs_hasher = None
|
||||
options_hasher = hashlib.md5()
|
||||
options_hasher.update(str(defines).encode('utf-8') + str(compile_args).encode('utf-8'));
|
||||
options_hash = options_hasher.hexdigest()
|
||||
options_hasher = None
|
||||
root, ext = os.path.splitext(kernel_filename)
|
||||
kernel_hash = root \
|
||||
+ "_" + CudaContext.hash_kernel( \
|
||||
os.path.join(self.module_path, kernel_filename), \
|
||||
kernel_path, \
|
||||
include_dirs=[self.module_path] + include_dirs) \
|
||||
+ "_" + kwargs_hash \
|
||||
+ "_" + options_hash \
|
||||
+ ext
|
||||
cached_kernel_filename = os.path.join(self.cache_path, kernel_hash)
|
||||
|
||||
@@ -210,7 +212,7 @@ class CudaContext(object):
|
||||
|
||||
with io.open(cached_kernel_filename, "rb") as file:
|
||||
file_str = file.read()
|
||||
module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler)
|
||||
module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args)
|
||||
|
||||
kernel = module.get_function(kernel_function_name)
|
||||
kernel.prepare(prepared_call_args)
|
||||
@@ -223,7 +225,7 @@ class CudaContext(object):
|
||||
|
||||
#Create kernel string
|
||||
kernel_string = ""
|
||||
for key, value in kwargs.items():
|
||||
for key, value in defines.items():
|
||||
kernel_string += "#define {:s} {:s}\n".format(str(key), str(value))
|
||||
kernel_string += '#include "{:s}"'.format(os.path.join(self.module_path, kernel_filename))
|
||||
if (self.use_cache):
|
||||
@@ -235,8 +237,11 @@ class CudaContext(object):
|
||||
|
||||
|
||||
with Common.Timer("compiler") as timer:
|
||||
cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, no_extern_c=no_extern_c, cache_dir=False)
|
||||
module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler)
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", message="The CUDA compiler succeeded, but said the following:\nkernel.cu", category=UserWarning)
|
||||
cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
|
||||
module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args)
|
||||
if (self.use_cache):
|
||||
with io.open(cached_kernel_filename, "wb") as file:
|
||||
file.write(cubin)
|
||||
|
||||
Reference in New Issue
Block a user