2025-07-01 19:43:34 +02:00

41 lines
1.3 KiB
Python

import logging
class BaseArakawaA2D(object):
    """
    A base class to be used to represent an Arakawa A type (unstaggered, logically Cartesian) grid.

    Holds one device-side array per input variable in ``self.gpu_variables``.
    Variables are addressed by their non-negative integer position.
    """

    def __init__(self, stream, nx, ny, halo_x, halo_y, cpu_variables, array_type):
        """
        Uploads initial data to the GPU device

        Each entry of ``cpu_variables`` is wrapped by calling
        ``array_type(stream, nx, ny, halo_x, halo_y, cpu_variable)``. The
        resulting objects are stored in ``self.gpu_variables``, in input order.

        :param stream: stream handle, forwarded unchanged to ``array_type``.
        :param nx: forwarded to ``array_type`` (presumably the grid size in x -- TODO confirm).
        :param ny: forwarded to ``array_type`` (presumably the grid size in y -- TODO confirm).
        :param halo_x: forwarded to ``array_type`` (presumably the halo width in x -- TODO confirm).
        :param halo_y: forwarded to ``array_type`` (presumably the halo width in y -- TODO confirm).
        :param cpu_variables: iterable of host-side data, one entry per variable.
        :param array_type: callable (typically a class) that builds each device array.
        """
        self.logger = logging.getLogger(__name__)
        self.gpu_variables = [
            array_type(stream, nx, ny, halo_x, halo_y, cpu_variable)
            for cpu_variable in cpu_variables
        ]

    def __getitem__(self, key):
        """
        Return the device array stored for variable number ``key``.

        :raises TypeError: if ``key`` is not an ``int`` (``bool`` is rejected too).
        :raises IndexError: if ``key`` is negative or ``>= len(self.gpu_variables)``.
        """
        # bool is a subclass of int; keep rejecting it as a variable index.
        if not isinstance(key, int) or isinstance(key, bool):
            raise TypeError("Indexing is int based")
        # Negative (from-the-end) indices are deliberately unsupported, and
        # key == len must be rejected here rather than by the list itself.
        if not 0 <= key < len(self.gpu_variables):
            raise IndexError("Out of bounds")
        return self.gpu_variables[key]

    def download(self, stream, variables=None):
        """
        Enables downloading data from the GPU device to Python

        :param stream: stream handle passed to each array's ``download``.
        :param variables: iterable of variable indices to fetch. ``None``
            (the default) fetches every variable, in order.
        :returns: list holding, for each requested index ``i`` (in request
            order), the result of
            ``self.gpu_variables[i].download(stream, asynch=True)``.
        :raises IndexError: if a requested index is negative or out of range.
            Downloads for indices earlier in ``variables`` have already been
            issued when this is raised.

        NOTE(review): downloads are requested with ``asynch=True``. Presumably
        the caller must synchronize ``stream`` before using the results --
        confirm against ``array_type.download``.
        """
        n_vars = len(self.gpu_variables)
        if variables is None:
            variables = range(n_vars)
        cpu_variables = []
        for i in variables:
            # Reject negative indices as __getitem__ does; a bare list index
            # would otherwise silently wrap -1 to the last variable.
            if not 0 <= i < n_vars:
                raise IndexError(f"Variable {i} is out of range")
            cpu_variables.append(self.gpu_variables[i].download(stream, asynch=True))
        return cpu_variables