diff --git a/GPUSimulators/common/arrays/hip/arkawa2d.py b/GPUSimulators/common/arrays/hip/arkawa2d.py index 67aaf6c..f1e53c2 100644 --- a/GPUSimulators/common/arrays/hip/arkawa2d.py +++ b/GPUSimulators/common/arrays/hip/arkawa2d.py @@ -1,4 +1,5 @@ import ctypes +from typing import Union import numpy as np from hip import hip, hipblas @@ -13,12 +14,28 @@ class HIPArakawaA2D(BaseArakawaA2D): A class representing an Arakawa A type (unstaggered, logically Cartesian) grid """ - def __init__(self, stream, nx, ny, halo_x, halo_y, cpu_variables): + def __init__(self, stream: hip.ihipStream_t, nx: int, ny: int, halo_x: int, halo_y: int, cpu_variables: list[Union[np.ndarray, None]]): """ Uploads initial data to the GPU device """ super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables, HIPArray2D) + # Variables for ``__sum_array`` + # TODO should have a way of not hardcoding the dtype + dtype = np.float32 + self.__result_h = np.zeros(1, dtype=dtype) + self.__num_bytes = self.__result_h.itemsize + self.__result_d = hip_check(hip.hipMalloc(self.__num_bytes)) + self.__total_sum_d = hip_check(hip.hipMalloc(self.__num_bytes)) + + self.__handle = hip_check(hipblas.hipblasCreate()) + + def __del__(self): + # Cleanup GPU variables in ``__sum_array`` + hip_check(hipblas.hipblasDestroy(self.__handle)) + hip_check(hip.hipFree(self.__result_d)) + hip_check(hip.hipFree(self.__total_sum_d)) + def check(self): """ Checks that data is still sane @@ -31,8 +48,7 @@ class HIPArakawaA2D(BaseArakawaA2D): if np.isnan(var_sum): raise ValueError("Data contains NaN values!") - @staticmethod - def __sum_array(array: HIPArray2D) -> np.ndarray[tuple[int]]: + def __sum_array(self, array: HIPArray2D) -> np.ndarray[tuple[int]]: """ Sum all the elements in HIPArray2D using hipblas. Args: @@ -40,35 +56,22 @@ class HIPArakawaA2D(BaseArakawaA2D): Returns: The sum of all the elements in ``array``. """ - dtype = array.dtype - result_h = np.zeros(1, dtype=dtype) - num_bytes = dtype.itemsize - result_d = hip_check(hip.hipMalloc(num_bytes)) - - # Sum the ``data_h`` array using hipblas - handle = hip_check(hipblas.hipblasCreate()) # Using pitched memory, so we need to sum row by row - total_sum_d = hip_check(hip.hipMalloc(num_bytes)) - hip_check(hip.hipMemset(total_sum_d, 0, num_bytes)) + hip_check(hip.hipMemset(self.__total_sum_d, 0, self.__num_bytes)) width, height = array.shape for y in range(height): row_ptr = int(array.data) + y * array.pitch_d - hip_check(hipblas.hipblasSasum(handle, width, row_ptr, 1, result_d)) + hip_check(hipblas.hipblasSasum(self.__handle, width, row_ptr, 1, self.__result_d)) - hip_check(hipblas.hipblasSaxpy(handle, 1, ctypes.c_float(1.0), result_d, 1, total_sum_d, 1)) - - hip_check(hip.hipMemcpy(result_h, total_sum_d, num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)) + hip_check( + hipblas.hipblasSaxpy(self.__handle, 1, ctypes.c_float(1.0), self.__result_d, 1, self.__total_sum_d, 1)) # Copy over the result from the device - hip_check(hip.hipMemcpy(result_h, total_sum_d, num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)) + hip_check(hip.hipMemcpy(self.__result_h, self.__total_sum_d, self.__num_bytes, + hip.hipMemcpyKind.hipMemcpyDeviceToHost)) - # Cleanup - hip_check(hipblas.hipblasDestroy(handle)) - hip_check(hip.hipFree(result_d)) - hip_check(hip.hipFree(total_sum_d)) - - return result_h + return self.__result_h