Refactoring

This commit is contained in:
André R. Brodtkorb
2018-12-03 12:09:15 +01:00
parent b266567d09
commit ae6404f05e
11 changed files with 239 additions and 99 deletions

View File

@@ -135,58 +135,14 @@ class MPIGrid(object):
grid = np.flip(np.sort(grid))
return grid
def getExtent(self, width, height, rank):
"""
Function which returns the extent of node with rank
rank in the grid
"""
i, j = self.getCoordinate(rank)
x0 = i * width
y0 = j * height
x1 = x0+width
y1 = y0+height
return [x0, x1, y0, y1]
def gatherData(self, data, rank=0):
"""
Function which gathers the data onto node with rank
rank
"""
#Get shape of data
ny, nx = data.shape
#Create list of buffers to return
retval = []
#If we are the target node, recieve from others
#otherwise send to target
if (self.comm.rank == rank):
mpi_requests = []
retval = []
#Loop over all nodes
for k in range(0, self.comm.size):
#If k equal target node, add our own data
#Otherwise receive it from node k
if (k == rank):
retval += [data]
else:
buffer = np.empty((ny, nx), dtype=np.float32)
retval += [buffer]
mpi_requests += [self.comm.Irecv(buffer, source=k, tag=k)]
#Wait for transfers to complete
for mpi_request in mpi_requests:
mpi_request.wait()
else:
mpi_request = self.comm.Isend(data, dest=rank, tag=self.comm.rank)
mpi_request.wait()
return retval
def gather(self, data, root=0):
out_data = None
if (self.comm.rank == root):
out_data = np.empty([self.comm.size] + list(data.shape), dtype=data.dtype)
self.comm.Gather(data, out_data, root)
return out_data
class MPISimulator(Simulator.BaseSimulator):
@@ -259,8 +215,8 @@ class MPISimulator(Simulator.BaseSimulator):
self.exchange()
self.sim.substep(dt, step_number)
def download(self):
return self.sim.download()
def getOutput(self):
return self.sim.getOutput()
def synchronize(self):
self.sim.synchronize()
@@ -273,7 +229,22 @@ class MPISimulator(Simulator.BaseSimulator):
global_dt = np.empty(1, dtype=np.float32)
self.grid.comm.Allreduce(local_dt, global_dt, op=MPI.MIN)
self.logger.debug("Local dt: {:f}, global dt: {:f}".format(local_dt[0], global_dt[0]))
return global_dt[0]
return global_dt[0]
def getExtent(self):
"""
Function which returns the extent of node with rank
rank in the grid
"""
width = self.sim.nx*self.sim.dx
height = self.sim.ny*self.sim.dy
i, j = self.grid.getCoordinate()
x0 = i * width
y0 = j * height
x1 = x0 + width
y1 = y0 + height
return [x0, x1, y0, y1]
def exchange(self):
####