import io
import logging
import os
from hashlib import md5

from GPUSimulators.common.utils import get_project_root, get_includes


class Context(object):
    """
    Class that manages either a HIP or CUDA context.

    This base class only provides the shared bookkeeping (logger, module
    dictionary, kernel source path, on-disk cache directory) and the kernel
    hashing used to key the cache. ``get_module`` and ``synchronize`` raise
    ``NotImplementedError`` here and must be provided by a subclass.
    """

    def __init__(self, language: str, device=0, context_flags=None, use_cache=True, autotuning=True):
        """
        Create a new context.

        Args:
            language: Name of the sub-directory (next to this file) that is
                stored as ``self.module_path``.
            device: Not used by this base class.
            context_flags: Not used by this base class.
            use_cache: If true, the cache directory is created on disk.
            autotuning: Not used by this base class; ``self.autotuner`` is
                always initialised to ``None`` here.
        """
        self.use_cache = use_cache
        self.logger = logging.getLogger(__name__)
        self.modules = {}

        self.module_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), f"{language}")

        # NOTE(review): presumably populated by subclasses when autotuning is requested - confirm.
        self.autotuner = None

        # The cache path is always computed (one directory per concrete class name),
        # but the directory is only created when caching is enabled.
        self.cache_path = os.path.join(get_project_root(), ".fvm_cache", type(self).__name__.lower())
        if self.use_cache:
            # exist_ok avoids the check-then-create race when several processes
            # construct a context at the same time.
            os.makedirs(self.cache_path, exist_ok=True)
            self.logger.info("Using cache dir %s", self.cache_path)

    def __del__(self):
        """
        Cleans up the context.

        The base class holds nothing that needs releasing, so this is a no-op.
        """
        pass

    def __str__(self):
        """
        Gives the context id.

        The base class has no device context, so it falls back to the class
        name. Returning a string (rather than ``None``) keeps ``str(obj)`` and
        string formatting from raising ``TypeError``.
        """
        return type(self).__name__

    def hash_kernel(self, kernel_filename: str, include_dirs: list[str]) -> str:
        """
        Generate a kernel ID for the caches.

        The kernel file and every file it (transitively) includes are read,
        and both their contents and their modification times are fed into the
        hash, so editing or touching any of them yields a new ID.

        Args:
            kernel_filename: Path to the kernel file.
            include_dirs: Directories to search for ``#include`` in the kernel
                file. The directory of the including file is always searched
                first. ``None`` is treated as an empty list.

        Returns:
            MD5 hash (hex digest) for the kernel in the cache.

        Raises:
            RuntimeError: When more than 100 ``#include``s have been resolved,
                which most likely indicates a circular include.
        """
        max_includes = 100
        num_includes = 0
        kernel_hasher = md5()
        search_dirs = list(include_dirs) if include_dirs else []

        # Work list of files still to hash; processed last-in, first-out.
        pending = [kernel_filename]
        while pending:
            if num_includes > max_includes:
                raise RuntimeError("Maximum number of includes reached.\n"
                                   f"Potential circular include in {kernel_filename}?")

            filename = pending.pop()
            modified = os.path.getmtime(filename)

            with io.open(filename, "r") as file:
                file_str = file.read()

            # Hash the contents as well as the modification time
            kernel_hasher.update(file_str.encode('utf-8'))
            kernel_hasher.update(str(modified).encode('utf-8'))

            # Queue everything that looks like an ``#include``
            file_path = os.path.dirname(filename)
            for include_file in get_includes(file_str):
                # The first directory that contains the file wins
                for include_path in [file_path] + search_dirs:
                    temp_path = os.path.join(include_path, include_file)
                    if os.path.isfile(temp_path):
                        pending.append(temp_path)
                        # Counted to be able to bail out on circular includes
                        num_includes += 1
                        break

        return kernel_hasher.hexdigest()

    def get_module(self, kernel_filename: str,
                   function: str,
                   include_dirs: dict = None,
                   defines: list[str] = None,
                   compile_args: dict = None,
                   jit_compile_args: dict = None):
        """
        Reads a text file and creates a kernel from that.

        Abstract in this base class.

        NOTE(review): ``include_dirs`` is annotated as ``dict`` here while
        ``hash_kernel`` expects ``list[str]`` - confirm which one is intended.

        Raises:
            NotImplementedError: Always; needs to be implemented in subclass.
        """
        raise NotImplementedError("Needs to be implemented in subclass")

    def synchronize(self):
        """
        Synchronizes all the streams, etc.

        Abstract in this base class.

        Raises:
            NotImplementedError: Always; needs to be implemented in subclass.
        """
        raise NotImplementedError("Needs to be implemented in subclass")