context with hip are deprecated

This commit is contained in:
Hicham Agueny 2024-02-27 17:22:07 +01:00
parent 163eb02a0b
commit ef32115e48

View File

@ -24,7 +24,8 @@ import gc
from IPython.core import magic_arguments from IPython.core import magic_arguments
from IPython.core.magic import line_magic, Magics, magics_class from IPython.core.magic import line_magic, Magics, magics_class
import pycuda.driver as cuda #import pycuda.driver as cuda
from hip import hip, hiprtc
from GPUSimulators import Common, CudaContext from GPUSimulators import Common, CudaContext
@ -41,6 +42,20 @@ class MagicCudaContext(Magics):
'--no_cache', '-nc', action="store_true", help='Disable caching of kernels') '--no_cache', '-nc', action="store_true", help='Disable caching of kernels')
@magic_arguments.argument( @magic_arguments.argument(
'--no_autotuning', '-na', action="store_true", help='Disable autotuning of kernels') '--no_autotuning', '-na', action="store_true", help='Disable autotuning of kernels')
def hip_check(call_result):
err = call_result[0]
result = call_result[1:]
if len(result) == 1:
result = result[0]
if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
raise RuntimeError(str(err))
elif (
isinstance(err, hiprtc.hiprtcResult)
and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS
):
raise RuntimeError(str(err))
return result
def cuda_context_handler(self, line): def cuda_context_handler(self, line):
args = magic_arguments.parse_argstring(self.cuda_context_handler, line) args = magic_arguments.parse_argstring(self.cuda_context_handler, line)
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
@ -49,7 +64,8 @@ class MagicCudaContext(Magics):
context_flags = None context_flags = None
if (args.blocking): if (args.blocking):
context_flags = cuda.ctx_flags.SCHED_BLOCKING_SYNC #context_flags = cuda.ctx_flags.SCHED_BLOCKING_SYNC
context_flags = hip_check(hip.hipSetDeviceFlags(hip.hipDeviceScheduleBlockingSync))
if args.name in self.shell.user_ns.keys(): if args.name in self.shell.user_ns.keys():
self.logger.debug("Context already registered! Ignoring") self.logger.debug("Context already registered! Ignoring")
@ -63,11 +79,13 @@ class MagicCudaContext(Magics):
# this function will be called on exceptions in any cell # this function will be called on exceptions in any cell
def custom_exc(shell, etype, evalue, tb, tb_offset=None): def custom_exc(shell, etype, evalue, tb, tb_offset=None):
self.logger.exception("Exception caught: Resetting to CUDA context %s", args.name) self.logger.exception("Exception caught: Resetting to CUDA context %s", args.name)
while (cuda.Context.get_current() != None): #while (cuda.Context.get_current() != None):
context = cuda.Context.get_current() while (hip.hipCtxGetCurrent() != None):
#context = cuda.Context.get_current()
context = hip_check(hip.hipCtxGetCurrent())
self.logger.info("Popping <%s>", str(context.handle)) self.logger.info("Popping <%s>", str(context.handle))
cuda.Context.pop() #cuda.Context.pop()
hip.hipCtxPopCurrent()
if args.name in self.shell.user_ns.keys(): if args.name in self.shell.user_ns.keys():
self.logger.info("Pushing <%s>", str(self.shell.user_ns[args.name].cuda_context.handle)) self.logger.info("Pushing <%s>", str(self.shell.user_ns[args.name].cuda_context.handle))
self.shell.user_ns[args.name].cuda_context.push() self.shell.user_ns[args.name].cuda_context.push()
@ -88,10 +106,13 @@ class MagicCudaContext(Magics):
import atexit import atexit
def exitfunc(): def exitfunc():
self.logger.info("Exitfunc: Resetting CUDA context stack") self.logger.info("Exitfunc: Resetting CUDA context stack")
while (cuda.Context.get_current() != None): #while (cuda.Context.get_current() != None):
context = cuda.Context.get_current() while (hip.hipCtxGetCurrent() != None):
#context = cuda.Context.get_current()
context = hip_check(hip.hipCtxGetCurrent())
self.logger.info("`-> Popping <%s>", str(context.handle)) self.logger.info("`-> Popping <%s>", str(context.handle))
cuda.Context.pop() #cuda.Context.pop()
hip.hipCtxPopCurrent()
self.logger.debug("==================================================================") self.logger.debug("==================================================================")
atexit.register(exitfunc) atexit.register(exitfunc)