mirror of
https://github.com/smyalygames/FiniteVolumeGPU.git
synced 2025-09-14 19:22:17 +02:00
feat(array): improve checking the array for NaNs
This commit is contained in:
parent
26c0eab7c8
commit
87474dcb20
@ -1,4 +1,5 @@
|
||||
import ctypes
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from hip import hip, hipblas
|
||||
@ -13,12 +14,28 @@ class HIPArakawaA2D(BaseArakawaA2D):
|
||||
A class representing an Arakawa A type (unstaggered, logically Cartesian) grid
|
||||
"""
|
||||
|
||||
def __init__(self, stream, nx, ny, halo_x, halo_y, cpu_variables):
|
||||
def __init__(self, stream: hip.ihipStream_t, nx: int, ny: int, halo_x: int, halo_y: int, cpu_variables: list[Union[np.ndarray, None]]):
|
||||
"""
|
||||
Uploads initial data to the GPU device
|
||||
"""
|
||||
super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables, HIPArray2D)
|
||||
|
||||
# Variables for ``__sum_array``
|
||||
# TODO should have a way of not hardcoding the dtype
|
||||
dtype = np.float32
|
||||
self.__result_h = np.zeros(1, dtype=dtype)
|
||||
self.__num_bytes = self.__result_h.itemsize
|
||||
self.__result_d = hip_check(hip.hipMalloc(self.__num_bytes))
|
||||
self.__total_sum_d = hip_check(hip.hipMalloc(self.__num_bytes))
|
||||
|
||||
self.__handle = hip_check(hipblas.hipblasCreate())
|
||||
|
||||
def __del__(self):
|
||||
# Cleanup GPU variables in ``__sum_array``
|
||||
hip_check(hipblas.hipblasDestroy(self.__handle))
|
||||
hip_check(hip.hipFree(self.__result_d))
|
||||
hip_check(hip.hipFree(self.__total_sum_d))
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
Checks that data is still sane
|
||||
@ -31,8 +48,7 @@ class HIPArakawaA2D(BaseArakawaA2D):
|
||||
if np.isnan(var_sum):
|
||||
raise ValueError("Data contains NaN values!")
|
||||
|
||||
@staticmethod
|
||||
def __sum_array(array: HIPArray2D) -> np.ndarray[tuple[int]]:
|
||||
def __sum_array(self, array: HIPArray2D) -> np.ndarray[tuple[int]]:
|
||||
"""
|
||||
Sum all the elements in HIPArray2D using hipblas.
|
||||
Args:
|
||||
@ -40,35 +56,22 @@ class HIPArakawaA2D(BaseArakawaA2D):
|
||||
Returns:
|
||||
The sum of all the elements in ``array``.
|
||||
"""
|
||||
dtype = array.dtype
|
||||
result_h = np.zeros(1, dtype=dtype)
|
||||
num_bytes = dtype.itemsize
|
||||
result_d = hip_check(hip.hipMalloc(num_bytes))
|
||||
|
||||
# Sum the ``data_h`` array using hipblas
|
||||
handle = hip_check(hipblas.hipblasCreate())
|
||||
|
||||
# Using pitched memory, so we need to sum row by row
|
||||
total_sum_d = hip_check(hip.hipMalloc(num_bytes))
|
||||
hip_check(hip.hipMemset(total_sum_d, 0, num_bytes))
|
||||
hip_check(hip.hipMemset(self.__total_sum_d, 0, self.__num_bytes))
|
||||
|
||||
width, height = array.shape
|
||||
|
||||
for y in range(height):
|
||||
row_ptr = int(array.data) + y * array.pitch_d
|
||||
|
||||
hip_check(hipblas.hipblasSasum(handle, width, row_ptr, 1, result_d))
|
||||
hip_check(hipblas.hipblasSasum(self.__handle, width, row_ptr, 1, self.__result_d))
|
||||
|
||||
hip_check(hipblas.hipblasSaxpy(handle, 1, ctypes.c_float(1.0), result_d, 1, total_sum_d, 1))
|
||||
|
||||
hip_check(hip.hipMemcpy(result_h, total_sum_d, num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
|
||||
hip_check(
|
||||
hipblas.hipblasSaxpy(self.__handle, 1, ctypes.c_float(1.0), self.__result_d, 1, self.__total_sum_d, 1))
|
||||
|
||||
# Copy over the result from the device
|
||||
hip_check(hip.hipMemcpy(result_h, total_sum_d, num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
|
||||
hip_check(hip.hipMemcpy(self.__result_h, self.__total_sum_d, self.__num_bytes,
|
||||
hip.hipMemcpyKind.hipMemcpyDeviceToHost))
|
||||
|
||||
# Cleanup
|
||||
hip_check(hipblas.hipblasDestroy(handle))
|
||||
hip_check(hip.hipFree(result_d))
|
||||
hip_check(hip.hipFree(total_sum_d))
|
||||
|
||||
return result_h
|
||||
return self.__result_h
|
||||
|
Loading…
x
Reference in New Issue
Block a user