Refactoring CudaArray and ArakawaA grid

This commit is contained in:
André R. Brodtkorb
2018-08-23 20:44:49 +02:00
parent 5668e28f99
commit 918d22b257
10 changed files with 452 additions and 159 deletions

View File

@@ -26,7 +26,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
#Import packages we need
import numpy as np
from GPUSimulators import Simulator
from GPUSimulators import Simulator, Common
@@ -60,9 +60,7 @@ class KP07_dimsplit (Simulator.BaseSimulator):
# Call super constructor
super().__init__(context, \
h0, hu0, hv0, \
nx, ny, \
2, 2, \
dx, dy, dt, \
g, \
block_width, block_height);
@@ -75,6 +73,16 @@ class KP07_dimsplit (Simulator.BaseSimulator):
BLOCK_WIDTH=self.local_size[0], \
BLOCK_HEIGHT=self.local_size[1])
#Create data by uploading to device
self.u0 = Common.ArakawaA2D(self.stream, \
nx, ny, \
2, 2, \
[h0, hu0, hv0])
self.u1 = Common.ArakawaA2D(self.stream, \
nx, ny, \
2, 2, \
[None, None, None])
def __str__(self):
return "Kurganov-Petrova 2007 dimensionally split"
@@ -91,13 +99,13 @@ class KP07_dimsplit (Simulator.BaseSimulator):
self.g, \
self.theta, \
np.int32(0), \
self.data.h0.data.gpudata, self.data.h0.data.strides[0], \
self.data.hu0.data.gpudata, self.data.hu0.data.strides[0], \
self.data.hv0.data.gpudata, self.data.hv0.data.strides[0], \
self.data.h1.data.gpudata, self.data.h1.data.strides[0], \
self.data.hu1.data.gpudata, self.data.hu1.data.strides[0], \
self.data.hv1.data.gpudata, self.data.hv1.data.strides[0])
self.data.swap()
self.u0[0].data.gpudata, self.u0[0].data.strides[0], \
self.u0[1].data.gpudata, self.u0[1].data.strides[0], \
self.u0[2].data.gpudata, self.u0[2].data.strides[0], \
self.u1[0].data.gpudata, self.u1[0].data.strides[0], \
self.u1[1].data.gpudata, self.u1[1].data.strides[0], \
self.u1[2].data.gpudata, self.u1[2].data.strides[0])
self.u0, self.u1 = self.u1, self.u0
self.t += dt
def stepDimsplitYX(self, dt):
@@ -107,13 +115,14 @@ class KP07_dimsplit (Simulator.BaseSimulator):
self.g, \
self.theta, \
np.int32(1), \
self.data.h0.data.gpudata, self.data.h0.data.strides[0], \
self.data.hu0.data.gpudata, self.data.hu0.data.strides[0], \
self.data.hv0.data.gpudata, self.data.hv0.data.strides[0], \
self.data.h1.data.gpudata, self.data.h1.data.strides[0], \
self.data.hu1.data.gpudata, self.data.hu1.data.strides[0], \
self.data.hv1.data.gpudata, self.data.hv1.data.strides[0])
self.data.swap()
self.u0[0].data.gpudata, self.u0[0].data.strides[0], \
self.u0[1].data.gpudata, self.u0[1].data.strides[0], \
self.u0[2].data.gpudata, self.u0[2].data.strides[0], \
self.u1[0].data.gpudata, self.u1[0].data.strides[0], \
self.u1[1].data.gpudata, self.u1[1].data.strides[0], \
self.u1[2].data.gpudata, self.u1[2].data.strides[0])
self.u0, self.u1 = self.u1, self.u0
self.t += dt
def download(self):
return self.u0.download(self.stream)