diff --git a/GPUSimulators/IPythonMagic.py b/GPUSimulators/IPythonMagic.py index fa452df..92baeb8 100644 --- a/GPUSimulators/IPythonMagic.py +++ b/GPUSimulators/IPythonMagic.py @@ -24,7 +24,8 @@ import gc from IPython.core import magic_arguments from IPython.core.magic import line_magic, Magics, magics_class -import pycuda.driver as cuda +#import pycuda.driver as cuda +from hip import hip, hiprtc from GPUSimulators import Common, CudaContext @@ -41,6 +42,20 @@ class MagicCudaContext(Magics): '--no_cache', '-nc', action="store_true", help='Disable caching of kernels') @magic_arguments.argument( '--no_autotuning', '-na', action="store_true", help='Disable autotuning of kernels') + def hip_check(call_result): + err = call_result[0] + result = call_result[1:] + if len(result) == 1: + result = result[0] + if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: + raise RuntimeError(str(err)) + elif ( + isinstance(err, hiprtc.hiprtcResult) + and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS + ): + raise RuntimeError(str(err)) + return result + def cuda_context_handler(self, line): args = magic_arguments.parse_argstring(self.cuda_context_handler, line) self.logger = logging.getLogger(__name__) @@ -49,7 +64,8 @@ class MagicCudaContext(Magics): context_flags = None if (args.blocking): - context_flags = cuda.ctx_flags.SCHED_BLOCKING_SYNC + #context_flags = cuda.ctx_flags.SCHED_BLOCKING_SYNC + context_flags = hip_check(hip.hipSetDeviceFlags(hip.hipDeviceScheduleBlockingSync)) if args.name in self.shell.user_ns.keys(): self.logger.debug("Context already registered! Ignoring") @@ -63,11 +79,13 @@ class MagicCudaContext(Magics): # this function will be called on exceptions in any cell def custom_exc(shell, etype, evalue, tb, tb_offset=None): self.logger.exception("Exception caught: Resetting to CUDA context %s", args.name) - while (cuda.Context.get_current() != None): - context = cuda.Context.get_current() + #while (cuda.Context.get_current() != None): + while (hip.hipCtxGetCurrent() != None): + #context = cuda.Context.get_current() + context = hip_check(hip.hipCtxGetCurrent()) self.logger.info("Popping <%s>", str(context.handle)) - cuda.Context.pop() - + #cuda.Context.pop() + hip.hipCtxPopCurrent() if args.name in self.shell.user_ns.keys(): self.logger.info("Pushing <%s>", str(self.shell.user_ns[args.name].cuda_context.handle)) self.shell.user_ns[args.name].cuda_context.push() @@ -88,10 +106,13 @@ class MagicCudaContext(Magics): import atexit def exitfunc(): self.logger.info("Exitfunc: Resetting CUDA context stack") - while (cuda.Context.get_current() != None): - context = cuda.Context.get_current() + #while (cuda.Context.get_current() != None): + while (hip.hipCtxGetCurrent() != None): + #context = cuda.Context.get_current() + context = hip_check(hip.hipCtxGetCurrent()) self.logger.info("`-> Popping <%s>", str(context.handle)) - cuda.Context.pop() + #cuda.Context.pop() + hip.hipCtxPopCurrent() self.logger.debug("==================================================================") atexit.register(exitfunc)