fix(gpu): escape sequence in regex

This commit is contained in:
Anthony Berg 2025-06-29 23:13:07 +02:00
parent 3aedef93cf
commit 0f0329cf36
2 changed files with 4 additions and 4 deletions

View File

@ -21,7 +21,7 @@ class Context(object):
self.logger = logging.getLogger(__name__)
self.modules = {}
self.module_path = os.path.dirname(os.path.realpath(__file__))
self.module_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + language)
self.autotuner = None

View File

@ -164,7 +164,7 @@ class CudaContext(object):
kernel_hasher.update(str(modified).encode('utf-8'))
# Find all includes
includes = re.findall('^\W*#include\W+(.+?)\W*$', file_str, re.M)
includes = re.findall('^\\W*#include\\W+(.+?)\\W*$', file_str, re.M)
# Loop over everything that looks like an 'include'
for include_file in includes:
@ -204,7 +204,7 @@ class CudaContext(object):
if defines is None:
defines = {}
if include_dirs is None:
include_dirs = []
include_dirs = [os.path.join(self.module_path) + "include"]
if compile_args is None:
compile_args = {'no_extern_c': True}
if jit_compile_args is None:
@ -221,7 +221,7 @@ class CudaContext(object):
if error_str:
self.logger.debug(f"Error: {error_str}")
kernel_filename = os.path.normpath("cuda/" + kernel_filename + ".cu")
kernel_filename = os.path.normpath(kernel_filename + ".cu")
kernel_path = os.path.abspath(os.path.join(self.module_path, kernel_filename))
# self.logger.debug("Getting %s", kernel_filename)