update cudaContext

This commit is contained in:
Hicham Agueny 2024-02-27 18:44:18 +01:00
parent ef32115e48
commit 3fd27d5bb8

View File

@ -36,11 +36,15 @@ import gc
#import pycuda.driver as cuda #import pycuda.driver as cuda
from hip import hip,hiprtc from hip import hip,hiprtc
from hip import rccl
from GPUSimulators import Autotuner, Common from GPUSimulators import Autotuner, Common
"""
Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
def hip_check(call_result): def hip_check(call_result):
err = call_result[0] err = call_result[0]
result = call_result[1:] result = call_result[1:]
@ -55,11 +59,6 @@ def hip_check(call_result):
raise RuntimeError(str(err)) raise RuntimeError(str(err))
return result return result
"""
Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True): def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True):
""" """
Create a new CUDA context Create a new CUDA context
@ -92,18 +91,18 @@ class CudaContext(object):
self.logger.debug(" => compute capability: %s", str(self.hip.hipDeviceComputeCapability(device))) self.logger.debug(" => compute capability: %s", str(self.hip.hipDeviceComputeCapability(device)))
# Create the CUDA context # Create the CUDA context
#In HIP there is no need to specify a scheduling policy (it is abstracted). Here the HIP runtime system manages the workload to fit a specifc target architecture if context_flags is None:
#if context_flags is None:
# context_flags=cuda.ctx_flags.SCHED_AUTO # context_flags=cuda.ctx_flags.SCHED_AUTO
context_flags=hip_check(hip.hipSetDeviceFlags(hip.hipDeviceScheduleAuto))
#self.cuda_context = self.cuda_device.make_context(flags=context_flags) #self.cuda_context = self.cuda_device.make_context(flags=context_flags)
self.cuda_context = self.hip_check(hip.hipCtxCreate(0, device))
#free, total = cuda.mem_get_info() #free, total = cuda.mem_get_info()
total = hip_check(hip.hipDeviceTotalMem(device)) total = hip_check(hip.hipDeviceTotalMem(device))
#self.logger.debug(" => memory: %d / %d MB available", int(free/(1024*1024)), int(total/(1024*1024))) #self.logger.debug(" => memory: %d / %d MB available", int(free/(1024*1024)), int(total/(1024*1024)))
self.logger.debug(" => memory: %d / %d MB available", int(total/(1024*1024))) self.logger.debug(" => memory: %d / %d MB available", int(total/(1024*1024)))
#self.logger.info("Created context handle <%s>", str(self.cuda_context.handle)) self.logger.info("Created context handle <%s>", str(self.cuda_context.handle))
#Create cache dir for cubin files #Create cache dir for cubin files
self.cache_path = os.path.join(self.module_path, "cuda_cache") self.cache_path = os.path.join(self.module_path, "cuda_cache")
@ -118,32 +117,37 @@ class CudaContext(object):
self.autotuner = Autotuner.Autotuner() self.autotuner = Autotuner.Autotuner()
# def __del__(self, *args): def __del__(self, *args):
# self.logger.info("Cleaning up CUDA context handle <%s>", str(self.cuda_context.handle)) self.logger.info("Cleaning up CUDA context handle <%s>", str(self.cuda_context.handle))
# Loop over all contexts in stack, and remove "this" # Loop over all contexts in stack, and remove "this"
# other_contexts = [] other_contexts = []
#while (cuda.Context.get_current() != None): #while (cuda.Context.get_current() != None):
while (hip.hipCtxGetCurrent() != None):
#context = cuda.Context.get_current() #context = cuda.Context.get_current()
# if (context.handle != self.cuda_context.handle): context = hip_check(hip.hipCtxGetCurrent())
# self.logger.debug("<%s> Popping <%s> (*not* ours)", str(self.cuda_context.handle), str(context.handle)) if (context.handle != self.cuda_context.handle):
# other_contexts = [context] + other_contexts self.logger.debug("<%s> Popping <%s> (*not* ours)", str(self.cuda_context.handle), str(context.handle))
other_contexts = [context] + other_contexts
#cuda.Context.pop() #cuda.Context.pop()
# else: hip.hipCtxPopCurrent()
# self.logger.debug("<%s> Popping <%s> (ours)", str(self.cuda_context.handle), str(context.handle)) else:
self.logger.debug("<%s> Popping <%s> (ours)", str(self.cuda_context.handle), str(context.handle))
#cuda.Context.pop() #cuda.Context.pop()
hip.hipCtxPopCurrent()
# Add all the contexts we popped that were not our own # Add all the contexts we popped that were not our own
# for context in other_contexts: for context in other_contexts:
# self.logger.debug("<%s> Pushing <%s>", str(self.cuda_context.handle), str(context.handle)) self.logger.debug("<%s> Pushing <%s>", str(self.cuda_context.handle), str(context.handle))
#cuda.Context.push(context) #cuda.Context.push(context)
hip_check(hip.hipCtxPushCurrent(context))
# self.logger.debug("<%s> Detaching", str(self.cuda_context.handle)) self.logger.debug("<%s> Detaching", str(self.cuda_context.handle))
# self.cuda_context.detach() self.cuda_context.detach()
# def __str__(self): def __str__(self):
# return "CudaContext id " + str(self.cuda_context.handle) return "CudaContext id " + str(self.cuda_context.handle)
def hash_kernel(kernel_filename, include_dirs): def hash_kernel(kernel_filename, include_dirs):
@ -244,9 +248,9 @@ class CudaContext(object):
with io.open(cached_kernel_filename, "rb") as file: with io.open(cached_kernel_filename, "rb") as file:
file_str = file.read() file_str = file.read()
#No hip counterpart of module_from_buffer #module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args)
module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args) module = hip_check(hip.hipModuleLoadDataEx(file_str, 0, None))
print("HIP module loaded: to be checked!")
self.modules[kernel_hash] = module self.modules[kernel_hash] = module
return module return module
@ -272,21 +276,10 @@ class CudaContext(object):
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The CUDA compiler succeeded, but said the following:\nkernel.cu", category=UserWarning) warnings.filterwarnings("ignore", message="The CUDA compiler succeeded, but said the following:\nkernel.cu", category=UserWarning)
#cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args) cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
#module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args) #module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args)
#HIP version of compilation: but "name_of_fct" needs to be defined. e.g. #cubin = hip_check(hiprtc.hiprtcCreateProgram(kernel_string.encode(), b"Kernel-Name", 0, [], []))
#source = b"""\
#extern "C" __global__ void name_of_fct(float factor, int n, short unused1, int unused2, float unused3, float *x) {
#int tid = threadIdx.x + blockIdx.x * blockDim.x;
#if (tid < n) {
#x[tid] *= factor;
# }
#}
#"""
prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string, b"name_of_fct", 0, [], []))
props = hip.hipDeviceProp_t() props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props,0)) hip_check(hip.hipGetDeviceProperties(props,0))
arch = props.gcnArchName arch = props.gcnArchName
@ -294,17 +287,16 @@ class CudaContext(object):
print(f"Compiling kernel for {arch}") print(f"Compiling kernel for {arch}")
cflags = [b"--offload-arch="+arch] cflags = [b"--offload-arch="+arch]
err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags) err, = hiprtc.hiprtcCompileProgram(cubin, len(cflags), cflags)
if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS: if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog)) log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(cubin))
log = bytearray(log_size) log = bytearray(log_size)
hip_check(hiprtc.hiprtcGetProgramLog(prog, log)) hip_check(hiprtc.hiprtcGetProgramLog(cubin, log))
raise RuntimeError(log.decode()) raise RuntimeError(log.decode())
code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog)) code_size = hip_check(hiprtc.hiprtcGetCodeSize(cubin))
code = bytearray(code_size) code = bytearray(code_size)
hip_check(hiprtc.hiprtcGetCode(prog, code)) hip_check(hiprtc.hiprtcGetCode(cubin, code))
module = hip_check(hip.hipModuleLoadData(code)) module = hip_check(hip.hipModuleLoadData(code))
#kernel = hip_check(hip.hipModuleGetFunction(module, b"name_of_fct"))
if (self.use_cache): if (self.use_cache):
with io.open(cached_kernel_filename, "wb") as file: with io.open(cached_kernel_filename, "wb") as file: