feat(gpu): make CudaContext inherit Context

This commit is contained in:
Anthony Berg 2025-06-30 20:36:02 +02:00
parent 1343cfd8c1
commit 86b56741e2

View File

@ -20,11 +20,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import os
import re
import io
import hashlib
import logging
import gc
import pycuda.compiler as cuda_compiler
@ -33,9 +30,10 @@ import pycuda.driver as cuda
from GPUSimulators import Autotuner
from GPUSimulators.common import Timer
from GPUSimulators.gpu.context import Context
class CudaContext(object):
class CudaContext(Context):
"""
Class which keeps track of the CUDA context and some helper functions
"""
@ -49,11 +47,7 @@ class CudaContext(object):
context_flags: To set a blocking context, provide ``cuda.ctx_flags.SCHED_BLOCKING_SYNC``.
"""
self.use_cache = use_cache
self.logger = logging.getLogger(__name__)
self.modules = {}
self.module_path = os.path.dirname(os.path.realpath(__file__))
super().__init__("cuda", device, context_flags, use_cache, autotuning)
# Initialize cuda (must be the first call to PyCUDA)
cuda.init(flags=0)
@ -82,14 +76,6 @@ class CudaContext(object):
self.logger.info(f"Created context handle <{str(self.cuda_context.handle)}>")
# Create cache dir for cubin files
self.cache_path = os.path.join(self.module_path, "cuda_cache")
if self.use_cache:
if not os.path.isdir(self.cache_path):
os.mkdir(self.cache_path)
self.logger.info(f"Using CUDA cache dir {self.cache_path}")
self.autotuner = None
if autotuning:
self.logger.info(
"Autotuning enabled. It may take several minutes to run the code the first time: have patience")
@ -121,76 +107,18 @@ class CudaContext(object):
def __str__(self):
    """Return a short human-readable identifier built from the CUDA context handle."""
    return f"CudaContext id {self.cuda_context.handle}"
def hash_kernel(self, kernel_filename: str, include_dirs: list[str]) -> str:
    """
    Generate a kernel ID for our caches.

    The hash covers the kernel source text and its modification time, and is
    extended recursively with every file pulled in through ``#include``
    directives that can be resolved on disk.

    Args:
        kernel_filename: Path to the kernel file.
        include_dirs: Directories to search for ``#include``s in the kernel file.

    Returns:
        MD5 hash (hex digest) for the kernel in the cache.

    Raises:
        RuntimeError: When more than ``max_includes`` (100) ``#include``s have
            been followed — almost certainly a circular include.
    """
    num_includes = 0
    max_includes = 100
    kernel_hasher = hashlib.md5()

    # Loop over the kernel file and its includes, and hash contents + mtimes
    # so that any change (to any included file) invalidates the cache entry.
    files = [kernel_filename]
    while files:
        if num_includes > max_includes:
            raise RuntimeError(
                f"Maximum number of includes reached - circular include in {kernel_filename}?")

        filename = files.pop()
        modified = os.path.getmtime(filename)

        # Hash both the file contents and its modification time
        with open(filename, "r") as file:
            file_str = file.read()
        kernel_hasher.update(file_str.encode('utf-8'))
        kernel_hasher.update(str(modified).encode('utf-8'))

        # Find everything that looks like an '#include <something>' directive
        includes = re.findall('^\\W*#include\\W+(.+?)\\W*$', file_str, re.M)

        for include_file in includes:
            # Search the including file's own directory first, then include_dirs
            file_path = os.path.dirname(filename)
            for include_path in [file_path] + include_dirs:
                temp_path = os.path.join(include_path, include_file)
                if os.path.isfile(temp_path):
                    # Found it: queue the file for hashing as well
                    files.append(temp_path)
                    num_includes += 1  # Counts follows, guarding against circular includes
                    break

    return kernel_hasher.hexdigest()
def get_module(self, kernel_filename: str,
function: str,
include_dirs: dict = None,
defines: dict[str: int] = None,
compile_args: dict = None, jit_compile_args: dict = None) -> cuda.Module:
defines: dict[str: dict] = None,
compile_args: dict = None,
jit_compile_args: dict = None) -> cuda.Module:
"""
Reads a text file and creates an OpenCL kernel from that.
Args:
kernel_filename: The file to use for the kernel.
function: The main function of the kernel.
include_dirs: List of directories for the ``#include``s referenced.
defines: Adds ``#define`` tags to the kernel, such as ``#define key value``.
compile_args: Adds other compiler options (parameters) for ``pycuda.compiler.compile()``.
@ -206,7 +134,7 @@ class CudaContext(object):
if include_dirs is None:
include_dirs = [os.path.join(self.module_path), "include"]
if compile_args is None:
compile_args = {'no_extern_c': True}
compile_args = {'cuda': {'no_extern_c': True}}
if jit_compile_args is None:
jit_compile_args = {}
@ -221,6 +149,8 @@ class CudaContext(object):
if error_str:
self.logger.debug(f"Error: {error_str}")
compile_args = compile_args.get('cuda')
kernel_filename = os.path.normpath(kernel_filename + ".cu")
kernel_path = os.path.abspath(os.path.join(self.module_path, kernel_filename))
# self.logger.debug("Getting %s", kernel_filename)