update cudaContext

This commit is contained in:
Hicham Agueny 2024-02-27 18:44:18 +01:00
parent ef32115e48
commit 3fd27d5bb8

View File

@ -36,12 +36,16 @@ import gc
#import pycuda.driver as cuda #import pycuda.driver as cuda
from hip import hip,hiprtc from hip import hip,hiprtc
from hip import rccl
from GPUSimulators import Autotuner, Common from GPUSimulators import Autotuner, Common
def hip_check(call_result): """
Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
def hip_check(call_result):
err = call_result[0] err = call_result[0]
result = call_result[1:] result = call_result[1:]
if len(result) == 1: if len(result) == 1:
@ -55,11 +59,6 @@ def hip_check(call_result):
raise RuntimeError(str(err)) raise RuntimeError(str(err))
return result return result
"""
Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True): def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True):
""" """
Create a new CUDA context Create a new CUDA context
@ -92,18 +91,18 @@ class CudaContext(object):
self.logger.debug(" => compute capability: %s", str(self.hip.hipDeviceComputeCapability(device))) self.logger.debug(" => compute capability: %s", str(self.hip.hipDeviceComputeCapability(device)))
# Create the CUDA context # Create the CUDA context
#In HIP there is no need to specify a scheduling policy (it is abstracted). Here the HIP runtime system manages the workload to fit a specifc target architecture if context_flags is None:
#if context_flags is None:
# context_flags=cuda.ctx_flags.SCHED_AUTO # context_flags=cuda.ctx_flags.SCHED_AUTO
context_flags=hip_check(hip.hipSetDeviceFlags(hip.hipDeviceScheduleAuto))
#self.cuda_context = self.cuda_device.make_context(flags=context_flags) #self.cuda_context = self.cuda_device.make_context(flags=context_flags)
self.cuda_context = self.hip_check(hip.hipCtxCreate(0, device))
#free, total = cuda.mem_get_info() #free, total = cuda.mem_get_info()
total = hip_check(hip.hipDeviceTotalMem(device)) total = hip_check(hip.hipDeviceTotalMem(device))
#self.logger.debug(" => memory: %d / %d MB available", int(free/(1024*1024)), int(total/(1024*1024))) #self.logger.debug(" => memory: %d / %d MB available", int(free/(1024*1024)), int(total/(1024*1024)))
self.logger.debug(" => memory: %d / %d MB available", int(total/(1024*1024))) self.logger.debug(" => memory: %d / %d MB available", int(total/(1024*1024)))
#self.logger.info("Created context handle <%s>", str(self.cuda_context.handle)) self.logger.info("Created context handle <%s>", str(self.cuda_context.handle))
#Create cache dir for cubin files #Create cache dir for cubin files
self.cache_path = os.path.join(self.module_path, "cuda_cache") self.cache_path = os.path.join(self.module_path, "cuda_cache")
@ -118,32 +117,37 @@ class CudaContext(object):
self.autotuner = Autotuner.Autotuner() self.autotuner = Autotuner.Autotuner()
# def __del__(self, *args): def __del__(self, *args):
# self.logger.info("Cleaning up CUDA context handle <%s>", str(self.cuda_context.handle)) self.logger.info("Cleaning up CUDA context handle <%s>", str(self.cuda_context.handle))
# Loop over all contexts in stack, and remove "this" # Loop over all contexts in stack, and remove "this"
# other_contexts = [] other_contexts = []
# while (cuda.Context.get_current() != None): #while (cuda.Context.get_current() != None):
# context = cuda.Context.get_current() while (hip.hipCtxGetCurrent() != None):
# if (context.handle != self.cuda_context.handle): #context = cuda.Context.get_current()
# self.logger.debug("<%s> Popping <%s> (*not* ours)", str(self.cuda_context.handle), str(context.handle)) context = hip_check(hip.hipCtxGetCurrent())
# other_contexts = [context] + other_contexts if (context.handle != self.cuda_context.handle):
# cuda.Context.pop() self.logger.debug("<%s> Popping <%s> (*not* ours)", str(self.cuda_context.handle), str(context.handle))
# else: other_contexts = [context] + other_contexts
# self.logger.debug("<%s> Popping <%s> (ours)", str(self.cuda_context.handle), str(context.handle)) #cuda.Context.pop()
# cuda.Context.pop() hip.hipCtxPopCurrent()
else:
self.logger.debug("<%s> Popping <%s> (ours)", str(self.cuda_context.handle), str(context.handle))
#cuda.Context.pop()
hip.hipCtxPopCurrent()
# Add all the contexts we popped that were not our own # Add all the contexts we popped that were not our own
# for context in other_contexts: for context in other_contexts:
# self.logger.debug("<%s> Pushing <%s>", str(self.cuda_context.handle), str(context.handle)) self.logger.debug("<%s> Pushing <%s>", str(self.cuda_context.handle), str(context.handle))
# cuda.Context.push(context) #cuda.Context.push(context)
hip_check(hip.hipCtxPushCurrent(context))
# self.logger.debug("<%s> Detaching", str(self.cuda_context.handle)) self.logger.debug("<%s> Detaching", str(self.cuda_context.handle))
# self.cuda_context.detach() self.cuda_context.detach()
# def __str__(self): def __str__(self):
# return "CudaContext id " + str(self.cuda_context.handle) return "CudaContext id " + str(self.cuda_context.handle)
def hash_kernel(kernel_filename, include_dirs): def hash_kernel(kernel_filename, include_dirs):
@ -244,9 +248,9 @@ class CudaContext(object):
with io.open(cached_kernel_filename, "rb") as file: with io.open(cached_kernel_filename, "rb") as file:
file_str = file.read() file_str = file.read()
#No hip counterpart of module_from_buffer #module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args)
module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args) module = hip_check(hip.hipModuleLoadDataEx(file_str, 0, None))
print("HIP module loaded: to be checked!")
self.modules[kernel_hash] = module self.modules[kernel_hash] = module
return module return module
@ -272,21 +276,10 @@ class CudaContext(object):
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The CUDA compiler succeeded, but said the following:\nkernel.cu", category=UserWarning) warnings.filterwarnings("ignore", message="The CUDA compiler succeeded, but said the following:\nkernel.cu", category=UserWarning)
#cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args) cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
#module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args) #module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args)
#HIP version of compilation: but "name_of_fct" needs to be defined. e.g. #cubin = hip_check(hiprtc.hiprtcCreateProgram(kernel_string.encode(), b"Kernel-Name", 0, [], []))
#source = b"""\
#extern "C" __global__ void name_of_fct(float factor, int n, short unused1, int unused2, float unused3, float *x) {
#int tid = threadIdx.x + blockIdx.x * blockDim.x;
#if (tid < n) {
#x[tid] *= factor;
# }
#}
#"""
prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string, b"name_of_fct", 0, [], []))
props = hip.hipDeviceProp_t() props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props,0)) hip_check(hip.hipGetDeviceProperties(props,0))
arch = props.gcnArchName arch = props.gcnArchName
@ -294,17 +287,16 @@ class CudaContext(object):
print(f"Compiling kernel for {arch}") print(f"Compiling kernel for {arch}")
cflags = [b"--offload-arch="+arch] cflags = [b"--offload-arch="+arch]
err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags) err, = hiprtc.hiprtcCompileProgram(cubin, len(cflags), cflags)
if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS: if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog)) log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(cubin))
log = bytearray(log_size) log = bytearray(log_size)
hip_check(hiprtc.hiprtcGetProgramLog(prog, log)) hip_check(hiprtc.hiprtcGetProgramLog(cubin, log))
raise RuntimeError(log.decode()) raise RuntimeError(log.decode())
code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog)) code_size = hip_check(hiprtc.hiprtcGetCodeSize(cubin))
code = bytearray(code_size) code = bytearray(code_size)
hip_check(hiprtc.hiprtcGetCode(prog, code)) hip_check(hiprtc.hiprtcGetCode(cubin, code))
module = hip_check(hip.hipModuleLoadData(code)) module = hip_check(hip.hipModuleLoadData(code))
#kernel = hip_check(hip.hipModuleGetFunction(module, b"name_of_fct"))
if (self.use_cache): if (self.use_cache):
with io.open(cached_kernel_filename, "wb") as file: with io.open(cached_kernel_filename, "wb") as file: