mirror of
https://github.com/smyalygames/FiniteVolumeGPU.git
synced 2025-11-27 22:51:47 +01:00
118 lines
3.9 KiB
Python
118 lines
3.9 KiB
Python
import logging
|
|
import os
|
|
import io
|
|
import logging
|
|
from hashlib import md5
|
|
|
|
from GPUSimulators.common.utils import get_project_root, get_includes
|
|
|
|
|
|
class Context(object):
    """
    Class that manages either a HIP or CUDA context.

    This base class only prepares state that does not depend on a GPU
    backend (logger, module table, kernel source directory, cache directory)
    and implements ``hash_kernel``. ``get_module`` and ``synchronize`` raise
    ``NotImplementedError`` here and must be provided by a subclass.
    """

    def __init__(self, language: str, device=0, context_flags=None, use_cache=True, autotuning=True):
        """
        Create a new context.

        Args:
            language: Name of the directory, next to this file, that is stored as ``module_path``.
            device: Not used by this base class (presumably consumed by subclasses).
            context_flags: Not used by this base class (presumably consumed by subclasses).
            use_cache: If true, the cache directory is created when missing.
            autotuning: Not used by this base class; ``autotuner`` is always initialised to ``None`` here.
        """
        self.use_cache = use_cache
        self.logger = logging.getLogger(__name__)
        self.modules = {}

        self.module_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), f"{language}")

        self.autotuner = None

        # One cache directory per concrete context class, named after the lower-cased class name.
        self.cache_path = os.path.join(get_project_root(), ".fvm_cache", type(self).__name__.lower())
        if self.use_cache:
            # ``exist_ok`` avoids the check-then-create race when several
            # processes construct a context at the same time.
            os.makedirs(self.cache_path, exist_ok=True)
            self.logger.info(f"Using cache dir {self.cache_path}")

    def __del__(self):
        """
        Cleans up the context.

        The base class holds no resources, so there is nothing to release.
        """
        pass

    def __str__(self):
        """
        Gives the context id.

        The base class has no id of its own, so it falls back to the default
        object representation. ``__str__`` must return a ``str``; returning
        ``None`` makes ``str(context)`` raise ``TypeError``.
        """
        return super().__str__()

    def hash_kernel(self, kernel_filename: str, include_dirs: list[str]) -> str:
        """
        Generate a kernel ID for the caches.

        The kernel file and every file reachable through its ``#include``s are
        visited; the text and the modification time of each visited file are
        fed into a single MD5 hash. A file that is included several times is
        hashed each time it is reached.

        Args:
            kernel_filename: Path to the kernel file.
            include_dirs: Directories to search for ``#include`` in the kernel file.
                The directory of the including file is always searched first.
                ``None`` is treated as no extra directories.

        Returns:
            MD5 hash for the kernel in the cache, as a hexadecimal string.

        Raises:
            RuntimeError: When more than 100 ``#include``s have been resolved,
                which most likely means that there is a circular include.
            OSError: When ``kernel_filename`` cannot be read.
        """
        max_includes = 100
        num_includes = 0
        kernel_hasher = md5()

        # Accept ``None`` and any iterable (list, tuple, ...) of directories.
        search_dirs = list(include_dirs) if include_dirs is not None else []

        # Files still to be hashed; the most recently found file is handled first.
        pending = [kernel_filename]
        while pending:
            if num_includes > max_includes:
                raise RuntimeError(f"Maximum number of includes reached.\n"
                                   f"Potential circular include in {kernel_filename}?")

            filename = pending.pop()
            modified = os.path.getmtime(filename)

            with io.open(filename, "r") as file:
                file_str = file.read()

            # Both the content and the timestamp contribute to the hash, so
            # either changing gives a different kernel ID.
            kernel_hasher.update(file_str.encode('utf-8'))
            kernel_hasher.update(str(modified).encode('utf-8'))

            # Includes are resolved relative to the including file first,
            # then against the given include directories.
            file_path = os.path.dirname(filename)
            for include_file in get_includes(file_str):
                for include_path in [file_path, *search_dirs]:
                    temp_path = os.path.join(include_path, include_file)
                    if os.path.isfile(temp_path):
                        pending.append(temp_path)
                        # Counted to guard against circular includes
                        num_includes += 1
                        break

        return kernel_hasher.hexdigest()

    def get_module(self, kernel_filename: str,
                   function: str,
                   include_dirs: dict = None,
                   defines: list[str] = None,
                   compile_args: dict = None,
                   jit_compile_args: dict = None):
        """
        Reads a text file and creates a kernel from that.

        Raises:
            NotImplementedError: Always; needs to be implemented in a subclass.
        """
        raise NotImplementedError("Needs to be implemented in subclass")

    def synchronize(self):
        """
        Synchronizes all the streams, etc.

        Raises:
            NotImplementedError: Always; needs to be implemented in a subclass.
        """
        raise NotImplementedError("Needs to be implemented in subclass")
|