fix(gpu): correct syntax for dictionary typing

This commit is contained in:
Anthony Berg 2025-09-03 18:04:43 +02:00
parent e92e7188f8
commit 97c2fd47e3
4 changed files with 29 additions and 24 deletions

View File

@ -64,8 +64,8 @@ class HIPContext(Context):
def get_module(self, kernel_filename: str,
function: str,
include_dirs: list[str] = None,
defines: dict[str: int] = None,
compile_args: dict[str: list] = None,
defines: dict[str, int] = None,
compile_args: dict[str, list] = None,
jit_compile_args: dict = None):
"""
Reads a ``.hip`` file and creates a HIP kernel from that.

View File

@ -21,19 +21,22 @@ class HIPHandler(BaseGPUHandler):
self.num_bytes = self.cfl_data_h.size * self.cfl_data_h.itemsize
self.cfl_data = hip_check(hip.hipMalloc(self.num_bytes)).configure(
typestr=np.finfo(self.dtype).dtype.name, shape=grid_size
typestr=self.cfl_data_h.dtype.str, shape=grid_size
)
def __del__(self):
hip_check(hip.hipFree(self.cfl_data))
def prepared_call(self, grid_size, block_size, stream, args):
if len(grid_size) < 3:
grid_size = (*grid_size, 1)
def prepared_call(self, grid_size: tuple[int, int], block_size: tuple[int, int, int], stream: hip.ihipStream_t,
args: list):
grid = hip.dim3(*grid_size)
block = hip.dim3(*block_size)
for i in range(len(args)):
val = args[i]
if isinstance(val, int) or isinstance(val, np.int32):
if isinstance(val, np.int64):
args[i] = ctypes.c_int64(val)
elif isinstance(val, int) or isinstance(val, np.int32):
args[i] = ctypes.c_int(val)
elif isinstance(val, float) or isinstance(val, np.float32):
args[i] = ctypes.c_float(val)
@ -42,29 +45,31 @@ class HIPHandler(BaseGPUHandler):
hip_check(hip.hipModuleLaunchKernel(
self.kernel,
*grid_size,
*block_size,
0,
stream,
None,
args
*grid,
*block,
sharedMemBytes=0,
stream=stream,
kernelParams=None,
extra=args
))
def array_fill(self, data, stream):
def array_fill(self, data: float, stream: hip.ihipStream_t):
self.cfl_data_h.fill(data)
hip_check(
hip.hipMemcpyAsync(self.cfl_data, self.cfl_data_h, self.num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice,
stream))
def array_min(self, stream):
def array_min(self, stream: hip.ihipStream_t) -> float:
handle = hip_check(hipblas.hipblasCreate())
value = np.empty(1, self.dtype)
hip_check(hipblas.hipblasIsamin(handle, self.cfl_data.size, self.cfl_data, 1, value))
value_h = np.empty(1, self.dtype)
value_d = hip_check(hip.hipMalloc(value_h.itemsize))
hip_check(hipblas.hipblasIsamin(handle, self.cfl_data.size, self.cfl_data, 1, value_d))
hip_check(hipblas.hipblasDestroy(handle))
hip_check(hip.hipMemcpy(value, self.cfl_data, self.cfl_data_h.itemsize, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
hip_check(
hip.hipMemcpy(value_h, self.cfl_data, self.cfl_data_h.itemsize, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
return value[0]

View File

@ -80,16 +80,16 @@ class BaseMPISimulator(BaseSimulator):
})
gi, gj = grid.get_coordinate()
# print("gi: " + str(gi) + ", gj: " + str(gj))
if gi == 0 and boundary_conditions.west != BoundaryCondition.Type.Periodic:
if (gi == 0 and boundary_conditions.west != BoundaryCondition.Type.Periodic):
self.west = None
new_boundary_conditions.west = boundary_conditions.west
if gj == 0 and boundary_conditions.south != BoundaryCondition.Type.Periodic:
if (gj == 0 and boundary_conditions.south != BoundaryCondition.Type.Periodic):
self.south = None
new_boundary_conditions.south = boundary_conditions.south
if gi == grid.x - 1 and boundary_conditions.east != BoundaryCondition.Type.Periodic:
if (gi == grid.x - 1 and boundary_conditions.east != BoundaryCondition.Type.Periodic):
self.east = None
new_boundary_conditions.east = boundary_conditions.east
if gj == grid.y - 1 and boundary_conditions.north != BoundaryCondition.Type.Periodic:
if (gj == grid.y - 1 and boundary_conditions.north != BoundaryCondition.Type.Periodic):
self.north = None
new_boundary_conditions.north = boundary_conditions.north
sim.set_boundary_conditions(new_boundary_conditions)

View File

@ -42,7 +42,7 @@ class BoundaryCondition(object):
Periodic = 2,
Reflective = 3
def __init__(self, types: dict[str: Type.Reflective]=None):
def __init__(self, types: dict[str, Type]=None):
"""
Constructor
"""