diff --git a/GPUSimulators/CudaContext.py b/GPUSimulators/CudaContext.py
index b4a2490..e77ef06 100644
--- a/GPUSimulators/CudaContext.py
+++ b/GPUSimulators/CudaContext.py
@@ -41,13 +41,7 @@ from hip import rccl
 
 from GPUSimulators import Autotuner, Common
 
-
-"""
-Class which keeps track of the CUDA context and some helper functions
-"""
-class CudaContext(object):
-
-    def hip_check(call_result):
+def hip_check(call_result):
     err = call_result[0]
     result = call_result[1:]
     if len(result) == 1:
@@ -61,6 +55,11 @@ class CudaContext(object):
         raise RuntimeError(str(err))
     return result
 
+"""
+Class which keeps track of the CUDA context and some helper functions
+"""
+class CudaContext(object):
+
     def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True):
         """
         Create a new CUDA context
@@ -245,6 +244,7 @@ class CudaContext(object):
 
             with io.open(cached_kernel_filename, "rb") as file:
                 file_str = file.read()
+                #No hip counterpart of module_from_buffer
                 module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args)
 
             self.modules[kernel_hash] = module
@@ -271,8 +271,41 @@ class CudaContext(object):
             import warnings
             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", message="The CUDA compiler succeeded, but said the following:\nkernel.cu", category=UserWarning)
-                cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
-                module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args)
+
+                #cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
+                #module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args)
+
+                #HIP version of compilation: but "name_of_fct" needs to be defined. e.g.
+                #source = b"""\
+                #extern "C" __global__ void name_of_fct(float factor, int n, short unused1, int unused2, float unused3, float *x) {
+                #    int tid = threadIdx.x + blockIdx.x * blockDim.x;
+                #    if (tid < n) {
+                #        x[tid] *= factor;
+                #    }
+                #}
+                #"""
+
+                prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string, b"name_of_fct", 0, [], []))
+
+                props = hip.hipDeviceProp_t()
+                hip_check(hip.hipGetDeviceProperties(props,0))
+                arch = props.gcnArchName
+
+                print(f"Compiling kernel for {arch}")
+
+                cflags = [b"--offload-arch="+arch]
+                err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
+                if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
+                    log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog))
+                    log = bytearray(log_size)
+                    hip_check(hiprtc.hiprtcGetProgramLog(prog, log))
+                    raise RuntimeError(log.decode())
+                code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog))
+                code = bytearray(code_size)
+                hip_check(hiprtc.hiprtcGetCode(prog, code))
+                module = hip_check(hip.hipModuleLoadData(code))
+                #kernel = hip_check(hip.hipModuleGetFunction(module, b"name_of_fct"))
+
             if (self.use_cache):
                 with io.open(cached_kernel_filename, "wb") as file:
                     file.write(cubin)
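
Two loose ends in the last hunk are worth flagging: `b"name_of_fct"` is only a placeholder (the real `extern "C"` kernel name from `kernel_string` is needed once `hipModuleGetFunction` is called), and the unchanged cache write at the bottom still refers to `cubin`, which the HIP path no longer defines; the compiled `code` object is what would go into the cache. Below is a minimal sketch of how this step could be factored, reusing `hip_check` from the top of the diff; `compile_kernel_code` and `kernel_name` are illustrative names, not part of the patch.

```python
# Sketch only: compile_kernel_code() and kernel_name are hypothetical names,
# not part of the patch. Relies on hip_check() as defined in the diff above.
from hip import hip, hiprtc

def compile_kernel_code(kernel_string: bytes, kernel_name: bytes) -> bytes:
    """Compile HIP C++ source with hiprtc and return the code object."""
    prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string, kernel_name, 0, [], []))

    # Target the architecture of the current device (e.g. gfx90a on MI250X)
    props = hip.hipDeviceProp_t()
    hip_check(hip.hipGetDeviceProperties(props, 0))
    cflags = [b"--offload-arch=" + props.gcnArchName]

    err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
    if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
        # Surface the hiprtc compiler log instead of a bare error code
        log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog))
        log = bytearray(log_size)
        hip_check(hiprtc.hiprtcGetProgramLog(prog, log))
        raise RuntimeError(log.decode())

    code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog))
    code = bytearray(code_size)
    hip_check(hiprtc.hiprtcGetCode(prog, code))
    return bytes(code)
```

Writing the returned `code` to `cached_kernel_filename` (instead of the now-undefined `cubin`) would keep the `use_cache` branch working.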
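As for the `#No hip counterpart of module_from_buffer` note in the cached branch: `hip.hipModuleLoadData()` accepts the code-object bytes read back from disk, so it effectively fills that role. A hedged end-to-end sketch follows, assuming the `compile_kernel_code` helper above and `hip_check` from the diff; the `scale_kernel` source and the `scale_kernel.hsaco` file name are invented for illustration.

```python
# Sketch only: assumes hip_check() (from the diff) and compile_kernel_code()
# (from the sketch above) are in scope; kernel and file names are invented.
import ctypes
import io

import numpy as np
from hip import hip

# Invented kernel standing in for kernel_string in get_module()
source = b"""\
extern "C" __global__ void scale_kernel(float factor, int n, float *x) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid < n) {
        x[tid] *= factor;
    }
}
"""
cached_kernel_filename = "scale_kernel.hsaco"  # illustrative cache file name

# First run: compile and cache the code object
code = compile_kernel_code(source, b"scale_kernel")
with io.open(cached_kernel_filename, "wb") as file:
    file.write(code)

# Later runs: the bytes read from the cache go straight to hipModuleLoadData,
# the closest analogue of pycuda's module_from_buffer in this code path
with io.open(cached_kernel_filename, "rb") as file:
    file_str = file.read()
module = hip_check(hip.hipModuleLoadData(file_str))
kernel = hip_check(hip.hipModuleGetFunction(module, b"scale_kernel"))

# Quick round-trip test of the loaded kernel
n = 32
x_h = np.ones(n, dtype=np.float32)
num_bytes = x_h.size * x_h.itemsize
x_d = hip_check(hip.hipMalloc(num_bytes))
hip_check(hip.hipMemcpy(x_d, x_h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))

block = hip.dim3(x=32)
grid = hip.dim3(x=(n + block.x - 1) // block.x)
hip_check(hip.hipModuleLaunchKernel(
    kernel,
    *grid, *block,
    sharedMemBytes=0,
    stream=None,
    kernelParams=None,
    extra=(ctypes.c_float(2.0), ctypes.c_int(n), x_d),  # matches scale_kernel's signature
))
hip_check(hip.hipMemcpy(x_h, x_d, num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
hip_check(hip.hipFree(x_d))
assert np.allclose(x_h, 2.0)
```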