From a9c3a5155618aafd3e03319c043e99c834bd55de Mon Sep 17 00:00:00 2001 From: Anthony Berg Date: Thu, 3 Jul 2025 13:15:36 +0200 Subject: [PATCH] feat(gpu): add assigning device for HIP and make a string for context --- GPUSimulators/gpu/hip_context.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/GPUSimulators/gpu/hip_context.py b/GPUSimulators/gpu/hip_context.py index 56b0829..0dfb1bc 100644 --- a/GPUSimulators/gpu/hip_context.py +++ b/GPUSimulators/gpu/hip_context.py @@ -14,7 +14,7 @@ class HIPContext(Context): Class that manages the HIP context. """ - def __init__(self, device=0, context_flags=None, use_cache=True, autotuning=False): + def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=False): """ Creates a new HIP context. """ @@ -25,6 +25,11 @@ class HIPContext(Context): self.logger.info(f"HIP Python version {hip_main.HIP_VERSION_NAME}") self.logger.info(f"ROCm version {hip_main.ROCM_VERSION_NAME}") + if device is None: + device = 0 + + hip_check(hip.hipSetDevice(device)) + # Device information props = hip.hipDeviceProp_t() hip_check(hip.hipGetDeviceProperties(props, device)) @@ -49,6 +54,10 @@ class HIPContext(Context): for prog in self.prog.values(): hip_check(hiprtc.hiprtcDestroyProgram(prog.createRef())) + def __str__(self): + device_handle = hip_check(hip.hipGetDevice()) + return f"HIPContext id {device_handle}" + def get_module(self, kernel_filename: str, function: str, include_dirs: list[str] = None,