mirror of
https://github.com/smyalygames/FiniteVolumeGPU.git
synced 2026-01-14 15:48:43 +01:00
Refactoring
This commit is contained in:
@@ -135,58 +135,14 @@ class MPIGrid(object):
|
||||
grid = np.flip(np.sort(grid))
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def getExtent(self, width, height, rank):
|
||||
"""
|
||||
Function which returns the extent of node with rank
|
||||
rank in the grid
|
||||
"""
|
||||
i, j = self.getCoordinate(rank)
|
||||
x0 = i * width
|
||||
y0 = j * height
|
||||
x1 = x0+width
|
||||
y1 = y0+height
|
||||
return [x0, x1, y0, y1]
|
||||
|
||||
|
||||
def gatherData(self, data, rank=0):
|
||||
"""
|
||||
Function which gathers the data onto node with rank
|
||||
rank
|
||||
"""
|
||||
#Get shape of data
|
||||
ny, nx = data.shape
|
||||
|
||||
#Create list of buffers to return
|
||||
retval = []
|
||||
|
||||
#If we are the target node, recieve from others
|
||||
#otherwise send to target
|
||||
if (self.comm.rank == rank):
|
||||
mpi_requests = []
|
||||
retval = []
|
||||
|
||||
#Loop over all nodes
|
||||
for k in range(0, self.comm.size):
|
||||
#If k equal target node, add our own data
|
||||
#Otherwise receive it from node k
|
||||
if (k == rank):
|
||||
retval += [data]
|
||||
else:
|
||||
buffer = np.empty((ny, nx), dtype=np.float32)
|
||||
retval += [buffer]
|
||||
mpi_requests += [self.comm.Irecv(buffer, source=k, tag=k)]
|
||||
|
||||
#Wait for transfers to complete
|
||||
for mpi_request in mpi_requests:
|
||||
mpi_request.wait()
|
||||
else:
|
||||
mpi_request = self.comm.Isend(data, dest=rank, tag=self.comm.rank)
|
||||
mpi_request.wait()
|
||||
|
||||
return retval
|
||||
|
||||
def gather(self, data, root=0):
|
||||
out_data = None
|
||||
if (self.comm.rank == root):
|
||||
out_data = np.empty([self.comm.size] + list(data.shape), dtype=data.dtype)
|
||||
self.comm.Gather(data, out_data, root)
|
||||
return out_data
|
||||
|
||||
|
||||
class MPISimulator(Simulator.BaseSimulator):
|
||||
@@ -259,8 +215,8 @@ class MPISimulator(Simulator.BaseSimulator):
|
||||
self.exchange()
|
||||
self.sim.substep(dt, step_number)
|
||||
|
||||
def download(self):
|
||||
return self.sim.download()
|
||||
def getOutput(self):
|
||||
return self.sim.getOutput()
|
||||
|
||||
def synchronize(self):
|
||||
self.sim.synchronize()
|
||||
@@ -273,7 +229,22 @@ class MPISimulator(Simulator.BaseSimulator):
|
||||
global_dt = np.empty(1, dtype=np.float32)
|
||||
self.grid.comm.Allreduce(local_dt, global_dt, op=MPI.MIN)
|
||||
self.logger.debug("Local dt: {:f}, global dt: {:f}".format(local_dt[0], global_dt[0]))
|
||||
return global_dt[0]
|
||||
return global_dt[0]
|
||||
|
||||
|
||||
def getExtent(self):
|
||||
"""
|
||||
Function which returns the extent of node with rank
|
||||
rank in the grid
|
||||
"""
|
||||
width = self.sim.nx*self.sim.dx
|
||||
height = self.sim.ny*self.sim.dy
|
||||
i, j = self.grid.getCoordinate()
|
||||
x0 = i * width
|
||||
y0 = j * height
|
||||
x1 = x0 + width
|
||||
y1 = y0 + height
|
||||
return [x0, x1, y0, y1]
|
||||
|
||||
def exchange(self):
|
||||
####
|
||||
|
||||
Reference in New Issue
Block a user