feat(array): improve checking the array for NaNs

This commit is contained in:
Anthony Berg 2025-09-03 18:08:45 +02:00
parent 26c0eab7c8
commit 87474dcb20

View File

@ -1,4 +1,5 @@
import ctypes
from typing import Union
import numpy as np
from hip import hip, hipblas
@ -13,12 +14,28 @@ class HIPArakawaA2D(BaseArakawaA2D):
A class representing an Arakawa A type (unstaggered, logically Cartesian) grid
"""
def __init__(self, stream, nx, ny, halo_x, halo_y, cpu_variables):
def __init__(self, stream: hip.ihipStream_t, nx: int, ny: int, halo_x: int, halo_y: int, cpu_variables: list[Union[np.ndarray, None]]):
"""
Uploads initial data to the GPU device
"""
super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables, HIPArray2D)
# Variables for ``__sum_array``
# TODO should have a way of not hardcoding the dtype
dtype = np.float32
self.__result_h = np.zeros(1, dtype=dtype)
self.__num_bytes = self.__result_h.itemsize
self.__result_d = hip_check(hip.hipMalloc(self.__num_bytes))
self.__total_sum_d = hip_check(hip.hipMalloc(self.__num_bytes))
self.__handle = hip_check(hipblas.hipblasCreate())
def __del__(self):
# Cleanup GPU variables in ``__sum_array``
hip_check(hipblas.hipblasDestroy(self.__handle))
hip_check(hip.hipFree(self.__result_d))
hip_check(hip.hipFree(self.__total_sum_d))
def check(self):
"""
Checks that data is still sane
@ -31,8 +48,7 @@ class HIPArakawaA2D(BaseArakawaA2D):
if np.isnan(var_sum):
raise ValueError("Data contains NaN values!")
@staticmethod
def __sum_array(array: HIPArray2D) -> np.ndarray[tuple[int]]:
def __sum_array(self, array: HIPArray2D) -> np.ndarray[tuple[int]]:
"""
Sum all the elements in HIPArray2D using hipblas.
Args:
@ -40,35 +56,22 @@ class HIPArakawaA2D(BaseArakawaA2D):
Returns:
The sum of all the elements in ``array``.
"""
dtype = array.dtype
result_h = np.zeros(1, dtype=dtype)
num_bytes = dtype.itemsize
result_d = hip_check(hip.hipMalloc(num_bytes))
# Sum the ``data_h`` array using hipblas
handle = hip_check(hipblas.hipblasCreate())
# Using pitched memory, so we need to sum row by row
total_sum_d = hip_check(hip.hipMalloc(num_bytes))
hip_check(hip.hipMemset(total_sum_d, 0, num_bytes))
hip_check(hip.hipMemset(self.__total_sum_d, 0, self.__num_bytes))
width, height = array.shape
for y in range(height):
row_ptr = int(array.data) + y * array.pitch_d
hip_check(hipblas.hipblasSasum(handle, width, row_ptr, 1, result_d))
hip_check(hipblas.hipblasSasum(self.__handle, width, row_ptr, 1, self.__result_d))
hip_check(hipblas.hipblasSaxpy(handle, 1, ctypes.c_float(1.0), result_d, 1, total_sum_d, 1))
hip_check(hip.hipMemcpy(result_h, total_sum_d, num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
hip_check(
hipblas.hipblasSaxpy(self.__handle, 1, ctypes.c_float(1.0), self.__result_d, 1, self.__total_sum_d, 1))
# Copy over the result from the device
hip_check(hip.hipMemcpy(result_h, total_sum_d, num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
hip_check(hip.hipMemcpy(self.__result_h, self.__total_sum_d, self.__num_bytes,
hip.hipMemcpyKind.hipMemcpyDeviceToHost))
# Cleanup
hip_check(hipblas.hipblasDestroy(handle))
hip_check(hip.hipFree(result_d))
hip_check(hip.hipFree(total_sum_d))
return result_h
return self.__result_h