feat(array): improve checking the array for NaNs

This commit is contained in:
Anthony Berg 2025-09-03 18:08:45 +02:00
parent 26c0eab7c8
commit 87474dcb20

View File

@ -1,4 +1,5 @@
import ctypes import ctypes
from typing import Union
import numpy as np import numpy as np
from hip import hip, hipblas from hip import hip, hipblas
@ -13,12 +14,28 @@ class HIPArakawaA2D(BaseArakawaA2D):
A class representing an Arakawa A type (unstaggered, logically Cartesian) grid A class representing an Arakawa A type (unstaggered, logically Cartesian) grid
""" """
def __init__(self, stream, nx, ny, halo_x, halo_y, cpu_variables): def __init__(self, stream: hip.ihipStream_t, nx: int, ny: int, halo_x: int, halo_y: int, cpu_variables: list[Union[np.ndarray, None]]):
""" """
Uploads initial data to the GPU device Uploads initial data to the GPU device
""" """
super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables, HIPArray2D) super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables, HIPArray2D)
# Scratch state for ``__sum_array``. It is created once here and reused by
# every call, instead of being allocated and destroyed on each sum.
# Everything allocated below is released again in ``__del__``.
# TODO should have a way of not hardcoding the dtype
dtype = np.float32
# One-element host array that ``__sum_array`` copies the final sum into.
self.__result_h = np.zeros(1, dtype=dtype)
# Size in bytes of a single element; used for the allocations below and
# for the memset/memcpy calls in ``__sum_array``.
self.__num_bytes = self.__result_h.itemsize
# Receives the result of ``hipblasSasum`` for a single row.
self.__result_d = hip_check(hip.hipMalloc(self.__num_bytes))
# Accumulates the per-row results (via ``hipblasSaxpy``) into the total.
self.__total_sum_d = hip_check(hip.hipMalloc(self.__num_bytes))
# hipBLAS handle passed to the ``hipblasSasum``/``hipblasSaxpy`` calls.
self.__handle = hip_check(hipblas.hipblasCreate())
def __del__(self):
    """
    Releases the GPU resources that ``__init__`` created for ``__sum_array``.

    The hipBLAS handle and the two scratch buffers are released
    independently of each other: every release is attempted even if an
    earlier one raises. Otherwise a failing ``hipblasDestroy``, or an
    ``__init__`` that raised before ``__handle`` was assigned, would skip the
    ``hipFree`` calls and leak the buffers.

    Errors are not suppressed; an exception from any release still
    propagates out of this method after the remaining releases have run.
    """
    # Cleanup GPU variables in ``__sum_array``
    try:
        hip_check(hipblas.hipblasDestroy(self.__handle))
    finally:
        # Nested so that the second buffer is freed even if freeing the
        # first one fails.
        try:
            hip_check(hip.hipFree(self.__result_d))
        finally:
            hip_check(hip.hipFree(self.__total_sum_d))
def check(self): def check(self):
""" """
Checks that data is still sane Checks that data is still sane
@ -31,8 +48,7 @@ class HIPArakawaA2D(BaseArakawaA2D):
if np.isnan(var_sum): if np.isnan(var_sum):
raise ValueError("Data contains NaN values!") raise ValueError("Data contains NaN values!")
@staticmethod def __sum_array(self, array: HIPArray2D) -> np.ndarray[tuple[int]]:
def __sum_array(array: HIPArray2D) -> np.ndarray[tuple[int]]:
""" """
Sum all the elements in HIPArray2D using hipblas. Sum all the elements in HIPArray2D using hipblas.
Args: Args:
@ -40,35 +56,22 @@ class HIPArakawaA2D(BaseArakawaA2D):
Returns: Returns:
The sum of all the elements in ``array``. The sum of all the elements in ``array``.
""" """
dtype = array.dtype
result_h = np.zeros(1, dtype=dtype)
num_bytes = dtype.itemsize
result_d = hip_check(hip.hipMalloc(num_bytes))
# Sum the ``data_h`` array using hipblas
handle = hip_check(hipblas.hipblasCreate())
# Using pitched memory, so we need to sum row by row # Using pitched memory, so we need to sum row by row
total_sum_d = hip_check(hip.hipMalloc(num_bytes)) hip_check(hip.hipMemset(self.__total_sum_d, 0, self.__num_bytes))
hip_check(hip.hipMemset(total_sum_d, 0, num_bytes))
width, height = array.shape width, height = array.shape
for y in range(height): for y in range(height):
row_ptr = int(array.data) + y * array.pitch_d row_ptr = int(array.data) + y * array.pitch_d
hip_check(hipblas.hipblasSasum(handle, width, row_ptr, 1, result_d)) hip_check(hipblas.hipblasSasum(self.__handle, width, row_ptr, 1, self.__result_d))
hip_check(hipblas.hipblasSaxpy(handle, 1, ctypes.c_float(1.0), result_d, 1, total_sum_d, 1)) hip_check(
hipblas.hipblasSaxpy(self.__handle, 1, ctypes.c_float(1.0), self.__result_d, 1, self.__total_sum_d, 1))
hip_check(hip.hipMemcpy(result_h, total_sum_d, num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
# Copy over the result from the device # Copy over the result from the device
hip_check(hip.hipMemcpy(result_h, total_sum_d, num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)) hip_check(hip.hipMemcpy(self.__result_h, self.__total_sum_d, self.__num_bytes,
hip.hipMemcpyKind.hipMemcpyDeviceToHost))
# Cleanup return self.__result_h
hip_check(hipblas.hipblasDestroy(handle))
hip_check(hip.hipFree(result_d))
hip_check(hip.hipFree(total_sum_d))
return result_h