add hip_check()

This commit is contained in:
Hicham Agueny 2024-02-26 11:01:45 +01:00 committed by GitHub
parent 7573668e53
commit 7a8214e3f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -43,6 +43,20 @@ import json
from hip import hip, hiprtc from hip import hip, hiprtc
from hip import hipblas from hip import hipblas
def hip_check(call_result):
err = call_result[0]
result = call_result[1:]
if len(result) == 1:
result = result[0]
if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
raise RuntimeError(str(err))
elif (
isinstance(err, hiprtc.hiprtcResult)
and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS
):
raise RuntimeError(str(err))
return result
def safeCall(cmd): def safeCall(cmd):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try: