add hip-based compilation

This commit is contained in:
Hicham Agueny 2024-01-17 17:17:48 +01:00
parent 74b24c33d7
commit 28ab54a86e

View File

@ -41,12 +41,6 @@ from hip import rccl
from GPUSimulators import Autotuner, Common
"""
Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
def hip_check(call_result):
err = call_result[0]
result = call_result[1:]
@ -61,6 +55,11 @@ class CudaContext(object):
raise RuntimeError(str(err))
return result
"""
Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True):
"""
Create a new CUDA context
@ -245,6 +244,7 @@ class CudaContext(object):
with io.open(cached_kernel_filename, "rb") as file:
file_str = file.read()
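#TODO: this cached-kernel load path is not ported to HIP yet; no HIP counterpart of PyCUDA's module_from_buffer is used here.
#NOTE(review): hip.hipModuleLoadData (used for freshly compiled code further down) may be a suitable replacement - confirm.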
module = cuda.module_from_buffer(file_str, message_handler=cuda_compile_message_handler, **jit_compile_args)
self.modules[kernel_hash] = module
@ -271,8 +271,41 @@ class CudaContext(object):
import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The CUDA compiler succeeded, but said the following:\nkernel.cu", category=UserWarning)
cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args)
#cubin = cuda_compiler.compile(kernel_string, include_dirs=include_dirs, cache_dir=False, **compile_args)
#module = cuda.module_from_buffer(cubin, message_handler=cuda_compile_message_handler, **jit_compile_args)
#HIP version of the compilation. Note that "name_of_fct" is a placeholder and still needs to be defined, e.g.:
#source = b"""\
#extern "C" __global__ void name_of_fct(float factor, int n, short unused1, int unused2, float unused3, float *x) {
#int tid = threadIdx.x + blockIdx.x * blockDim.x;
#if (tid < n) {
#x[tid] *= factor;
# }
#}
#"""
prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string, b"name_of_fct", 0, [], []))
props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props,0))
arch = props.gcnArchName
print(f"Compiling kernel for {arch}")
cflags = [b"--offload-arch="+arch]
err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog))
log = bytearray(log_size)
hip_check(hiprtc.hiprtcGetProgramLog(prog, log))
raise RuntimeError(log.decode())
code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog))
code = bytearray(code_size)
hip_check(hiprtc.hiprtcGetCode(prog, code))
module = hip_check(hip.hipModuleLoadData(code))
#kernel = hip_check(hip.hipModuleGetFunction(module, b"name_of_fct"))
if (self.use_cache):
with io.open(cached_kernel_filename, "wb") as file:
file.write(cubin)