Fixed order again

This commit is contained in:
André R. Brodtkorb
2018-11-15 16:47:13 +01:00
parent dcb849b705
commit 7592ad5b9f
22 changed files with 758 additions and 619 deletions

View File

@@ -48,7 +48,7 @@ class CudaContext(object):
self.blocking = blocking
self.use_cache = use_cache
self.logger = logging.getLogger(__name__)
self.kernels = {}
self.modules = {}
self.module_path = os.path.dirname(os.path.realpath(__file__))
@@ -164,12 +164,12 @@ class CudaContext(object):
break
return kernel_hasher.hexdigest()
"""
Reads a text file and creates an OpenCL kernel from that
"""
def get_prepared_kernel(self, kernel_filename, kernel_function_name, \
prepared_call_args, \
def get_module(self, kernel_filename,
include_dirs=[], \
defines={}, \
compile_args={'no_extern_c', True}, jit_compile_args={}):
@@ -206,9 +206,9 @@ class CudaContext(object):
cached_kernel_filename = os.path.join(self.cache_path, kernel_hash)
# If we have the kernel in our hashmap, return it
if (kernel_hash in self.kernels.keys()):
if (kernel_hash in self.modules.keys()):
self.logger.debug("Found kernel %s cached in hashmap (%s)", kernel_filename, kernel_hash)
return self.kernels[kernel_hash]
return self.modules[kernel_hash]
# If we have it on disk, return it
elif (self.use_cache and os.path.isfile(cached_kernel_filename)):
@@ -218,10 +218,8 @@ class CudaContext(object):
file_str = file.read()
module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args)
kernel = module.get_function(kernel_function_name)
kernel.prepare(prepared_call_args)
self.kernels[kernel_hash] = kernel
return kernel
self.modules[kernel_hash] = module
return module
# Otherwise, compile it from source
else:
@@ -250,19 +248,15 @@ class CudaContext(object):
with io.open(cached_kernel_filename, "wb") as file:
file.write(cubin)
kernel = module.get_function(kernel_function_name)
kernel.prepare(prepared_call_args)
self.kernels[kernel_hash] = kernel
return kernel
self.modules[kernel_hash] = module
return module
"""
Clears the kernel cache (useful for debugging & development)
"""
def clear_kernel_cache(self):
self.logger.debug("Clearing cache")
self.kernels = {}
self.modules = {}
gc.collect()
"""