fix(common): cyclical imports for arrays

This commit is contained in:
Anthony Berg 2025-06-25 13:57:56 +02:00
parent f4fff25539
commit 4df7b9b6b7
7 changed files with 16 additions and 16 deletions

View File

@ -1,8 +1,6 @@
# Objects
from GPUSimulators.common.arrays.cuda.arkawa2d import ArakawaA2D
from .arrays import *
from .common import *
from GPUSimulators.common.arrays.cuda.array2d import CudaArray2D
from GPUSimulators.common.arrays.cuda.array3d import CudaArray3D
from .data_dumper import DataDumper
from .ip_engine import IPEngine
from .popen_file_buffer import PopenFileBuffer

View File

@ -6,5 +6,3 @@ if __env_name in environ and environ.get(__env_name).lower() == "cuda":
from .cuda import *
else:
from .hip import *
# TODO this is temporary, remove
from .cuda import array3d

View File

@ -1,6 +1,6 @@
import logging
from GPUSimulators.common.arrays import Array2D
# from .typing import array2d
class BaseArakawaA2D(object):
@ -8,7 +8,7 @@ class BaseArakawaA2D(object):
A base class to be used to represent an Arakawa A type (unstaggered, logically Cartesian) grid.
"""
def __init__(self, stream, nx, ny, halo_x, halo_y, cpu_variables):
def __init__(self, stream, nx, ny, halo_x, halo_y, cpu_variables, array_type):
"""
Uploads initial data to the GPU device
"""
@ -16,7 +16,7 @@ class BaseArakawaA2D(object):
self.gpu_variables = []
for cpu_variable in cpu_variables:
self.gpu_variables += [Array2D(stream, nx, ny, halo_x, halo_y, cpu_variable)]
self.gpu_variables += [array_type(stream, nx, ny, halo_x, halo_y, cpu_variable)]
def __getitem__(self, key):
if type(key) != int:

View File

@ -1,3 +1,3 @@
from arkawa2d import ArakawaA2D
from array2d import CudaArray2D as Array2D
from array3d import CudaArray3D as Array3D
from .arkawa2d import CudaArakawaA2D as ArakawaA2D
from .array2d import CudaArray2D as Array2D
from .array3d import CudaArray3D as Array3D

View File

@ -1,10 +1,11 @@
import numpy as np
import pycuda.gpuarray
from GPUSimulators.common.arrays.arkawa2d import BaseArakawaA2D
from ..arkawa2d import BaseArakawaA2D
from .array2d import CudaArray2D
class ArakawaA2D(BaseArakawaA2D):
class CudaArakawaA2D(BaseArakawaA2D):
"""
A class representing an Arakawa A type (unstaggered, logically Cartesian) grid
"""
@ -13,7 +14,7 @@ class ArakawaA2D(BaseArakawaA2D):
"""
Uploads initial data to the GPU device
"""
super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables)
super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables, CudaArray2D)
def check(self):
"""

View File

@ -0,0 +1,3 @@
from .arkawa2d import HIPArakawaA2D as ArakawaA2D
from .array2d import HIPArray2D as Array2D
# from .array3d import HIPArray3D as Array3D

View File

@ -31,7 +31,7 @@ def _sum_array(array: HIPArray2D):
return result_h
class ArakawaA2D(BaseArakawaA2D):
class HIPArakawaA2D(BaseArakawaA2D):
"""
A class representing an Arakawa A type (unstaggered, logically Cartesian) grid
"""
@ -40,7 +40,7 @@ class ArakawaA2D(BaseArakawaA2D):
"""
Uploads initial data to the GPU device
"""
super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables)
super().__init__(stream, nx, ny, halo_x, halo_y, cpu_variables, HIPArray2D)
def check(self):
"""