update cudaContext

Hicham Agueny 2024-02-27 18:44:18 +01:00
parent ef32115e48
commit 3fd27d5bb8


@@ -36,11 +36,15 @@ import gc
#import pycuda.driver as cuda
from hip import hip,hiprtc
from hip import rccl
from GPUSimulators import Autotuner, Common
"""
Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
def hip_check(call_result):
err = call_result[0]
result = call_result[1:]
@@ -55,11 +59,6 @@ def hip_check(call_result):
raise RuntimeError(str(err))
return result
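For readers following the port, this is the error-handling idiom the rest of the file relies on: every hip/hiprtc binding returns a tuple whose first element is a status code. A minimal standalone sketch of the pattern (the hipError_t branch is assumed from the usual hip-python convention; only the hiprtcResult handling is visible in this hunk):

from hip import hip, hiprtc

def hip_check(call_result):
    # hip-python calls return (status, value, ...); split off the status code
    err = call_result[0]
    result = call_result[1:]
    if len(result) == 1:
        result = result[0]
    # Runtime errors (hipError_t) and compiler errors (hiprtcResult) are both handled
    if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
        raise RuntimeError(str(err))
    if isinstance(err, hiprtc.hiprtcResult) and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
        raise RuntimeError(str(err))
    return result

# Example: unwrap the (status, count) pair returned by hipGetDeviceCount
num_devices = hip_check(hip.hipGetDeviceCount())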
"""
Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True):
"""
Create a new CUDA context
@ -92,18 +91,18 @@ class CudaContext(object):
self.logger.debug(" => compute capability: %s", str(self.hip.hipDeviceComputeCapability(device)))
# Create the CUDA context
#In HIP there is no need to specify a scheduling policy (it is abstracted): the HIP runtime manages the workload to fit the specific target architecture
#if context_flags is None:
if context_flags is None:
# context_flags=cuda.ctx_flags.SCHED_AUTO
context_flags=hip_check(hip.hipSetDeviceFlags(hip.hipDeviceScheduleAuto))
#self.cuda_context = self.cuda_device.make_context(flags=context_flags)
self.cuda_context = hip_check(hip.hipCtxCreate(0, device))
#free, total = cuda.mem_get_info()
total = hip_check(hip.hipDeviceTotalMem(device))
#self.logger.debug(" => memory: %d / %d MB available", int(free/(1024*1024)), int(total/(1024*1024)))
self.logger.debug(" => memory: %d / %d MB available", int(total/(1024*1024)))
#self.logger.info("Created context handle <%s>", str(self.cuda_context.handle))
self.logger.info("Created context handle <%s>", str(self.cuda_context.handle))
#Create cache dir for cubin files
self.cache_path = os.path.join(self.module_path, "cuda_cache")
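As a rough standalone illustration of the device and context calls introduced in this hunk (a sketch only, reusing the hip_check helper sketched earlier and assuming device 0):

from hip import hip

device = 0

# HIP replaces CUDA context flags with per-device scheduling flags
hip_check(hip.hipSetDeviceFlags(hip.hipDeviceScheduleAuto))

# Driver-API style context, standing in for cuda_device.make_context()
context = hip_check(hip.hipCtxCreate(0, device))

# Only the total memory is queried here; there is no direct use of the
# free/total pair that cuda.mem_get_info() provided in the PyCUDA version
total = hip_check(hip.hipDeviceTotalMem(device))
print("total device memory: %d MB" % int(total / (1024 * 1024)))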
@@ -118,32 +117,37 @@ class CudaContext(object):
self.autotuner = Autotuner.Autotuner()
# def __del__(self, *args):
# self.logger.info("Cleaning up CUDA context handle <%s>", str(self.cuda_context.handle))
def __del__(self, *args):
self.logger.info("Cleaning up CUDA context handle <%s>", str(self.cuda_context.handle))
# Loop over all contexts in stack, and remove "this"
# other_contexts = []
other_contexts = []
#while (cuda.Context.get_current() != None):
while (hip_check(hip.hipCtxGetCurrent()) is not None):
#context = cuda.Context.get_current()
# if (context.handle != self.cuda_context.handle):
# self.logger.debug("<%s> Popping <%s> (*not* ours)", str(self.cuda_context.handle), str(context.handle))
# other_contexts = [context] + other_contexts
context = hip_check(hip.hipCtxGetCurrent())
if (context.handle != self.cuda_context.handle):
self.logger.debug("<%s> Popping <%s> (*not* ours)", str(self.cuda_context.handle), str(context.handle))
other_contexts = [context] + other_contexts
#cuda.Context.pop()
# else:
# self.logger.debug("<%s> Popping <%s> (ours)", str(self.cuda_context.handle), str(context.handle))
hip.hipCtxPopCurrent()
else:
self.logger.debug("<%s> Popping <%s> (ours)", str(self.cuda_context.handle), str(context.handle))
#cuda.Context.pop()
hip.hipCtxPopCurrent()
# Add all the contexts we popped that were not our own
# for context in other_contexts:
# self.logger.debug("<%s> Pushing <%s>", str(self.cuda_context.handle), str(context.handle))
for context in other_contexts:
self.logger.debug("<%s> Pushing <%s>", str(self.cuda_context.handle), str(context.handle))
#cuda.Context.push(context)
hip_check(hip.hipCtxPushCurrent(context))
# self.logger.debug("<%s> Detaching", str(self.cuda_context.handle))
# self.cuda_context.detach()
self.logger.debug("<%s> Detaching", str(self.cuda_context.handle))
hip_check(hip.hipCtxDestroy(self.cuda_context))
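A note on the cleanup above: hip.hipCtxGetCurrent() returns an (err, ctx) tuple, so it has to be unwrapped before testing for an empty context stack, and whether an empty stack yields None or a NULL-handle object depends on the binding. The PyCUDA-style detach() also has no counterpart on the HIP context object; hip.hipCtxDestroy is assumed above as the closest HIP equivalent. A hedged sketch of what the __del__ body is aiming at:

# Sketch of the __del__ body: pop every context, remember the ones that are
# not ours, then push them back so only our own context is removed.
other_contexts = []
context = hip_check(hip.hipCtxGetCurrent())
while context is not None:
    if context.handle != self.cuda_context.handle:
        self.logger.debug("<%s> Popping <%s> (*not* ours)", str(self.cuda_context.handle), str(context.handle))
        other_contexts = [context] + other_contexts
    else:
        self.logger.debug("<%s> Popping <%s> (ours)", str(self.cuda_context.handle), str(context.handle))
    hip_check(hip.hipCtxPopCurrent())
    context = hip_check(hip.hipCtxGetCurrent())

# Push back the contexts that belonged to someone else
for context in other_contexts:
    hip_check(hip.hipCtxPushCurrent(context))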
# def __str__(self):
# return "CudaContext id " + str(self.cuda_context.handle)
def __str__(self):
return "CudaContext id " + str(self.cuda_context.handle)
def hash_kernel(kernel_filename, include_dirs):
@@ -244,9 +248,9 @@ class CudaContext(object):
with io.open(cached_kernel_filename, "rb") as file:
file_str = file.read()
#No hip counterpart of module_from_buffer
module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args)
#module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args)
module = hip_check(hip.hipModuleLoadDataEx(file_str, 0, None))
print("HIP module loaded: to be checked!")
self.modules[kernel_hash] = module
return module
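For context, the cache-hit branch above boils down to reading the stored code object and handing it to hipModuleLoadDataEx. A minimal sketch with a hypothetical cache path and kernel name, reusing the hip_check helper from earlier:

import io
from hip import hip

cached_kernel_filename = "cuda_cache/kernel_1234.hip"   # placeholder path
with io.open(cached_kernel_filename, "rb") as file:
    file_str = file.read()

# No JIT options are passed, mirroring the call above
module = hip_check(hip.hipModuleLoadDataEx(file_str, 0, None))

# Kernel names are passed to the bindings as bytes; b"my_kernel" is illustrative
kernel = hip_check(hip.hipModuleGetFunction(module, b"my_kernel"))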
@@ -272,21 +276,10 @@ class CudaContext(object):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The CUDA compiler succeeded, but said the following:\nkernel.cu", category=UserWarning)
#cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
#module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args)
#HIP version of compilation: but "name_of_fct" needs to be defined. e.g.
#source = b"""\
#extern "C" __global__ void name_of_fct(float factor, int n, short unused1, int unused2, float unused3, float *x) {
#int tid = threadIdx.x + blockIdx.x * blockDim.x;
#if (tid < n) {
#x[tid] *= factor;
# }
#}
#"""
prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string, b"name_of_fct", 0, [], []))
#cubin = hip_check(hiprtc.hiprtcCreateProgram(kernel_string.encode(), b"Kernel-Name", 0, [], []))
props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props,0))
arch = props.gcnArchName
@@ -294,17 +287,16 @@ class CudaContext(object):
print(f"Compiling kernel for {arch}")
cflags = [b"--offload-arch="+arch]
err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
err, = hiprtc.hiprtcCompileProgram(cubin, len(cflags), cflags)
if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog))
log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(cubin))
log = bytearray(log_size)
hip_check(hiprtc.hiprtcGetProgramLog(prog, log))
hip_check(hiprtc.hiprtcGetProgramLog(cubin, log))
raise RuntimeError(log.decode())
code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog))
code_size = hip_check(hiprtc.hiprtcGetCodeSize(cubin))
code = bytearray(code_size)
hip_check(hiprtc.hiprtcGetCode(prog, code))
hip_check(hiprtc.hiprtcGetCode(cubin, code))
module = hip_check(hip.hipModuleLoadData(code))
#kernel = hip_check(hip.hipModuleGetFunction(module, b"name_of_fct"))
if (self.use_cache):
with io.open(cached_kernel_filename, "wb") as file:
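Taken together, the hiprtc path in this hunk amounts to: create a program, compile it for the architecture reported by hipGetDeviceProperties, extract the code object, and load it as a module. A hedged end-to-end sketch, where the kernel source and the name scale are placeholders echoing the example comment removed above:

from hip import hip, hiprtc

kernel_string = b"""\
extern "C" __global__ void scale(float factor, int n, float *x) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid < n) {
        x[tid] *= factor;
    }
}
"""

# hiprtc expects bytes for both the source and the program name
prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string, b"scale", 0, [], []))

# Compile for the architecture of device 0 (e.g. gfx90a)
props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
cflags = [b"--offload-arch=" + props.gcnArchName]

err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
    # Surface the compiler log on failure, as in the code above
    log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog))
    log = bytearray(log_size)
    hip_check(hiprtc.hiprtcGetProgramLog(prog, log))
    raise RuntimeError(log.decode())

# Extract the compiled code object and load it as a module
code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog))
code = bytearray(code_size)
hip_check(hiprtc.hiprtcGetCode(prog, code))
module = hip_check(hip.hipModuleLoadData(code))
kernel = hip_check(hip.hipModuleGetFunction(module, b"scale"))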