refactor(common): move sum array method to a static method

This commit is contained in:
Anthony Berg 2025-08-10 11:01:20 +02:00
parent 74b9fe7e66
commit d61e57bf06

View File

@ -8,14 +8,41 @@ from ..arkawa2d import BaseArakawaA2D
from .array2d import HIPArray2D from .array2d import HIPArray2D
def _sum_array(array: HIPArray2D): class HIPArakawaA2D(BaseArakawaA2D):
"""
A class representing an Arakawa A type (unstaggered, logically Cartesian) grid
"""
def __init__(self, stream, nx, ny, halo_x, halo_y, cpu_variables):
"""
Uploads initial data to the GPU device
"""
super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables, HIPArray2D)
def check(self):
"""
Checks that data is still sane
"""
for i, gpu_variable in enumerate(self.gpu_variables):
var_sum = self.__sum_array(gpu_variable)
self.logger.debug(f"Data {i} with size [{gpu_variable.nx} x {gpu_variable.ny}] "
+ f"has average {var_sum / (gpu_variable.nx * gpu_variable.ny)}")
if np.isnan(var_sum):
raise ValueError("Data contains NaN values!")
@staticmethod
def __sum_array(array: HIPArray2D) -> np.ndarray[tuple[int]]:
""" """
Sum all the elements in HIPArray2D using hipblas. Sum all the elements in HIPArray2D using hipblas.
Args: Args:
array: A HIPArray2D to compute the sum of. array: A HIPArray2D to compute the sum of.
Returns:
The sum of all the elements in ``array``.
""" """
result_h = np.zeros(1, dtype=array.dtype) dtype = array.dtype
num_bytes = result_h.strides[0] result_h = np.zeros(1, dtype=dtype)
num_bytes = dtype.itemsize
result_d = hip_check(hip.hipMalloc(num_bytes)) result_d = hip_check(hip.hipMalloc(num_bytes))
# Sum the ``data_h`` array using hipblas # Sum the ``data_h`` array using hipblas
@ -45,27 +72,3 @@ def _sum_array(array: HIPArray2D):
hip_check(hip.hipFree(total_sum_d)) hip_check(hip.hipFree(total_sum_d))
return result_h return result_h
class HIPArakawaA2D(BaseArakawaA2D):
"""
A class representing an Arakawa A type (unstaggered, logically Cartesian) grid
"""
def __init__(self, stream, nx, ny, halo_x, halo_y, cpu_variables):
"""
Uploads initial data to the GPU device
"""
super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables, HIPArray2D)
def check(self):
"""
Checks that data is still sane
"""
for i, gpu_variable in enumerate(self.gpu_variables):
var_sum = _sum_array(gpu_variable)
self.logger.debug(f"Data {i} with size [{gpu_variable.nx} x {gpu_variable.ny}] "
+ f"has average {var_sum / (gpu_variable.nx * gpu_variable.ny)}")
if np.isnan(var_sum):
raise ValueError("Data contains NaN values!")