mirror of
https://github.com/smyalygames/FiniteVolumeGPU_HIP.git
synced 2025-05-18 06:24:11 +02:00
context with hip are deprecated
This commit is contained in:
parent
163eb02a0b
commit
ef32115e48
@ -24,7 +24,8 @@ import gc
|
||||
|
||||
from IPython.core import magic_arguments
|
||||
from IPython.core.magic import line_magic, Magics, magics_class
|
||||
import pycuda.driver as cuda
|
||||
#import pycuda.driver as cuda
|
||||
from hip import hip, hiprtc
|
||||
|
||||
from GPUSimulators import Common, CudaContext
|
||||
|
||||
@ -41,6 +42,20 @@ class MagicCudaContext(Magics):
|
||||
'--no_cache', '-nc', action="store_true", help='Disable caching of kernels')
|
||||
@magic_arguments.argument(
|
||||
'--no_autotuning', '-na', action="store_true", help='Disable autotuning of kernels')
|
||||
def hip_check(call_result):
|
||||
err = call_result[0]
|
||||
result = call_result[1:]
|
||||
if len(result) == 1:
|
||||
result = result[0]
|
||||
if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
|
||||
raise RuntimeError(str(err))
|
||||
elif (
|
||||
isinstance(err, hiprtc.hiprtcResult)
|
||||
and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS
|
||||
):
|
||||
raise RuntimeError(str(err))
|
||||
return result
|
||||
|
||||
def cuda_context_handler(self, line):
|
||||
args = magic_arguments.parse_argstring(self.cuda_context_handler, line)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
@ -49,7 +64,8 @@ class MagicCudaContext(Magics):
|
||||
|
||||
context_flags = None
|
||||
if (args.blocking):
|
||||
context_flags = cuda.ctx_flags.SCHED_BLOCKING_SYNC
|
||||
#context_flags = cuda.ctx_flags.SCHED_BLOCKING_SYNC
|
||||
context_flags = hip_check(hip.hipSetDeviceFlags(hip.hipDeviceScheduleBlockingSync))
|
||||
|
||||
if args.name in self.shell.user_ns.keys():
|
||||
self.logger.debug("Context already registered! Ignoring")
|
||||
@ -63,11 +79,13 @@ class MagicCudaContext(Magics):
|
||||
# this function will be called on exceptions in any cell
|
||||
def custom_exc(shell, etype, evalue, tb, tb_offset=None):
|
||||
self.logger.exception("Exception caught: Resetting to CUDA context %s", args.name)
|
||||
while (cuda.Context.get_current() != None):
|
||||
context = cuda.Context.get_current()
|
||||
#while (cuda.Context.get_current() != None):
|
||||
while (hip.hipCtxGetCurrent() != None):
|
||||
#context = cuda.Context.get_current()
|
||||
context = hip_check(hip.hipCtxGetCurrent())
|
||||
self.logger.info("Popping <%s>", str(context.handle))
|
||||
cuda.Context.pop()
|
||||
|
||||
#cuda.Context.pop()
|
||||
hip.hipCtxPopCurrent()
|
||||
if args.name in self.shell.user_ns.keys():
|
||||
self.logger.info("Pushing <%s>", str(self.shell.user_ns[args.name].cuda_context.handle))
|
||||
self.shell.user_ns[args.name].cuda_context.push()
|
||||
@ -88,10 +106,13 @@ class MagicCudaContext(Magics):
|
||||
import atexit
|
||||
def exitfunc():
|
||||
self.logger.info("Exitfunc: Resetting CUDA context stack")
|
||||
while (cuda.Context.get_current() != None):
|
||||
context = cuda.Context.get_current()
|
||||
#while (cuda.Context.get_current() != None):
|
||||
while (hip.hipCtxGetCurrent() != None):
|
||||
#context = cuda.Context.get_current()
|
||||
context = hip_check(hip.hipCtxGetCurrent())
|
||||
self.logger.info("`-> Popping <%s>", str(context.handle))
|
||||
cuda.Context.pop()
|
||||
#cuda.Context.pop()
|
||||
hip.hipCtxPopCurrent()
|
||||
self.logger.debug("==================================================================")
|
||||
atexit.register(exitfunc)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user