refactor(mpi): follow PEP 8 naming scheme and replace .format() with f-strings

Anthony Berg 2025-06-24 17:55:23 +02:00
parent 7475d590eb
commit 0633561bbf


@@ -29,6 +29,75 @@ import pycuda.driver as cuda
 #import nvtx
 
+def get_grid(num_nodes, num_dims):
+    if not isinstance(num_nodes, int):
+        raise TypeError("Parameter `num_nodes` is not an integer.")
+    if not isinstance(num_dims, int):
+        raise TypeError("Parameter `num_dims` is not an integer.")
+
+    # Adapted from https://stackoverflow.com/questions/28057307/factoring-a-number-into-roughly-equal-factors
+    # Original code by https://stackoverflow.com/users/3928385/ishamael
+    # Factorizes a number into n roughly equal factors
+
+    # Dictionary to remember already computed permutations
+    memo = {}
+
+    def dp(n, left):  # returns tuple (cost, [factors])
+        """
+        Recursively searches through all factorizations
+        """
+        # Already tried: return an existing result
+        if (n, left) in memo:
+            return memo[(n, left)]
+
+        # Spent all factors: return the number itself
+        if left == 1:
+            return (n, [n])
+
+        # Find a new factor
+        i = 2
+        best = n
+        best_tuple = [n]
+        while i * i < n:
+            # If a factor is found
+            if n % i == 0:
+                # Factorize the remainder
+                rem = dp(n // i, left - 1)
+                # If the new permutation is better, save it
+                if rem[0] + i < best:
+                    best = rem[0] + i
+                    best_tuple = [i] + rem[1]
+            i += 1
+
+        # Store the calculation
+        memo[(n, left)] = (best, best_tuple)
+        return memo[(n, left)]
+
+    grid = dp(num_nodes, num_dims)[1]
+
+    if len(grid) < num_dims:
+        # Split the problematic factor 4
+        if 4 in grid:
+            grid.remove(4)
+            grid.append(2)
+            grid.append(2)
+
+        # Pad with ones to guarantee num_dims factors
+        grid = grid + [1]*(num_dims - len(grid))
+
+    # Sort in descending order
+    grid = np.sort(grid)
+    grid = grid[::-1]
+
+    # XXX: We only use vertical (north-south) partitioning for now
+    grid[0] = 1
+    grid[1] = num_nodes
+
+    return grid
 class MPIGrid(object):
     """
     Class which represents an MPI grid of nodes. Facilitates easy communication between
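For illustration (an editorial aside, not part of the commit): dp() minimizes the sum of the factors, so 12 nodes in two dimensions factor into [4, 3]; the override at the end of get_grid then forces every result into a single north-south column. A minimal usage sketch, assuming this file is importable as a module named mpi_simulator (a hypothetical name):

    from mpi_simulator import get_grid  # hypothetical module name

    grid = get_grid(12, 2)
    # dp(12, 2) picks [3, 4] (cost 3 + 4 = 7 beats 2 + 6 = 8); after the
    # descending sort and the north-south override, the result is a column:
    print(grid)  # [ 1 12]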
@@ -37,15 +106,16 @@ class MPIGrid(object):
     def __init__(self, comm, ndims=2):
         self.logger = logging.getLogger(__name__)
 
-        assert ndims == 2, "Unsupported number of dimensions. Must be two at the moment"
-        assert comm.size >= 1, "Must have at least one node"
+        if ndims != 2:
+            raise ValueError("Unsupported number of dimensions. Must be two at the moment")
+        if comm.size < 1:
+            raise ValueError("Must have at least one node")
 
-        self.grid = MPIGrid.get_grid(comm.size, ndims)
+        self.grid = get_grid(comm.size, ndims)
         self.comm = comm
 
-        self.logger.debug("Created MPI grid: {:}. Rank {:d} has coordinate {:}".format(
-            self.grid, self.comm.rank, self.getCoordinate()))
+        self.logger.debug(f"Created MPI grid: {self.grid}. Rank {self.comm.rank} has coordinate {self.get_coordinate()}")
 
     def get_coordinate(self, rank=None):
         if rank is None:
@@ -76,76 +146,10 @@ class MPIGrid(object):
         i, j = self.get_coordinate(self.comm.rank)
         j = (j+self.grid[1]-1) % self.grid[1]
         return self.get_rank(i, j)
 
-    def get_grid(num_nodes, num_dims):
-        assert(isinstance(num_nodes, int))
-        assert(isinstance(num_dims, int))
-
-        # Adapted from https://stackoverflow.com/questions/28057307/factoring-a-number-into-roughly-equal-factors
-        # Original code by https://stackoverflow.com/users/3928385/ishamael
-        # Factorizes a number into n roughly equal factors
-
-        #Dictionary to remember already computed permutations
-        memo = {}
-
-        def dp(n, left): # returns tuple (cost, [factors])
-            """
-            Recursively searches through all factorizations
-            """
-            #Already tried: return existing result
-            if (n, left) in memo:
-                return memo[(n, left)]
-
-            #Spent all factors: return number itself
-            if left == 1:
-                return (n, [n])
-
-            #Find new factor
-            i = 2
-            best = n
-            bestTuple = [n]
-            while i * i < n:
-                #If factor found
-                if n % i == 0:
-                    #Factorize remainder
-                    rem = dp(n // i, left - 1)
-                    #If new permutation better, save it
-                    if rem[0] + i < best:
-                        best = rem[0] + i
-                        bestTuple = [i] + rem[1]
-                i += 1
-
-            #Store calculation
-            memo[(n, left)] = (best, bestTuple)
-            return memo[(n, left)]
-
-        grid = dp(num_nodes, num_dims)[1]
-
-        if (len(grid) < num_dims):
-            #Split problematic 4
-            if (4 in grid):
-                grid.remove(4)
-                grid.append(2)
-                grid.append(2)
-
-            #Pad with ones to guarantee num_dims
-            grid = grid + [1]*(num_dims - len(grid))
-
-        #Sort in descending order
-        grid = np.sort(grid)
-        grid = grid[::-1]
-
-        # XXX: We only use vertical (north-south) partitioning for now
-        grid[0] = 1
-        grid[1] = num_nodes
-
-        return grid
     def gather(self, data, root=0):
         out_data = None
-        if (self.comm.rank == root):
+        if self.comm.rank == root:
             out_data = np.empty([self.comm.size] + list(data.shape), dtype=data.dtype)
         self.comm.Gather(data, out_data, root)
         return out_data
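Aside (not part of the commit): the neighbor lookup in the context lines at the top of this hunk uses modular arithmetic to wrap coordinates around the grid, so row 0's southern neighbor is the last row when the boundary is periodic. A standalone sketch of that index calculation:

    # Wrap-around neighbor lookup along one grid axis (illustrative).
    grid_rows = 4
    for j in range(grid_rows):
        south = (j + grid_rows - 1) % grid_rows
        print(j, "->", south)  # 0 -> 3, 1 -> 0, 2 -> 1, 3 -> 2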
@@ -206,7 +210,7 @@ class MPISimulator(Simulator.BaseSimulator):
"""
def __init__(self, sim, grid):
self.profiling_data_mpi = { 'start': {}, 'end': {} }
self.profiling_data_mpi = {'start': {}, 'end': {}}
self.profiling_data_mpi["start"]["t_mpi_halo_exchange"] = 0
self.profiling_data_mpi["end"]["t_mpi_halo_exchange"] = 0
self.profiling_data_mpi["start"]["t_mpi_halo_exchange_download"] = 0
@@ -221,7 +225,7 @@ class MPISimulator(Simulator.BaseSimulator):
         self.logger = logging.getLogger(__name__)
 
         autotuner = sim.context.autotuner
-        sim.context.autotuner = None;
+        sim.context.autotuner = None
         boundary_conditions = sim.get_boundary_conditions()
         super().__init__(sim.context,
                          sim.nx, sim.ny,
@@ -251,18 +255,18 @@ class MPISimulator(Simulator.BaseSimulator):
         })
 
         gi, gj = grid.get_coordinate()
         #print("gi: " + str(gi) + ", gj: " + str(gj))
-        if (gi == 0 and boundary_conditions.west != Simulator.BoundaryCondition.Type.Periodic):
+        if gi == 0 and boundary_conditions.west != Simulator.BoundaryCondition.Type.Periodic:
             self.west = None
-            new_boundary_conditions.west = boundary_conditions.west;
-        if (gj == 0 and boundary_conditions.south != Simulator.BoundaryCondition.Type.Periodic):
+            new_boundary_conditions.west = boundary_conditions.west
+        if gj == 0 and boundary_conditions.south != Simulator.BoundaryCondition.Type.Periodic:
             self.south = None
-            new_boundary_conditions.south = boundary_conditions.south;
-        if (gi == grid.grid[0]-1 and boundary_conditions.east != Simulator.BoundaryCondition.Type.Periodic):
+            new_boundary_conditions.south = boundary_conditions.south
+        if gi == grid.grid[0]-1 and boundary_conditions.east != Simulator.BoundaryCondition.Type.Periodic:
             self.east = None
-            new_boundary_conditions.east = boundary_conditions.east;
-        if (gj == grid.grid[1]-1 and boundary_conditions.north != Simulator.BoundaryCondition.Type.Periodic):
+            new_boundary_conditions.east = boundary_conditions.east
+        if gj == grid.grid[1]-1 and boundary_conditions.north != Simulator.BoundaryCondition.Type.Periodic:
             self.north = None
-            new_boundary_conditions.north = boundary_conditions.north;
+            new_boundary_conditions.north = boundary_conditions.north
         sim.set_boundary_conditions(new_boundary_conditions)
 
         #Get number of variables
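Aside (not part of the commit): a rank keeps the physical boundary condition only on the sides where it sits at the edge of the grid and the condition is not periodic; all other sides are handled by halo exchange. A toy check of which ranks keep the outer south/north conditions in the forced 1 x 4 north-south grid:

    # Illustrative only; mirrors the gi/gj tests above for grid = [1, 4].
    grid = [1, 4]
    for rank in range(grid[1]):
        gi, gj = 0, rank                   # coordinate of this rank
        keeps_south = (gj == 0)            # only the bottom row
        keeps_north = (gj == grid[1] - 1)  # only the top row
        print(rank, keeps_south, keeps_north)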
@@ -302,7 +306,7 @@ class MPISimulator(Simulator.BaseSimulator):
         self.out_n = cuda.pagelocked_empty((int(self.nvars), int(self.read_n[3]), int(self.read_n[2])), dtype=np.float32) #np.empty_like(self.in_n)
         self.out_s = cuda.pagelocked_empty((int(self.nvars), int(self.read_s[3]), int(self.read_s[2])), dtype=np.float32) #np.empty_like(self.in_s)
 
-        self.logger.debug("Simlator rank {:d} initialized on {:s}".format(self.grid.comm.rank, MPI.Get_processor_name()))
+        self.logger.debug(f"Simulator rank {self.grid.comm.rank} initialized on {MPI.Get_processor_name()}")
 
         self.full_exchange()
         sim.context.synchronize()
@@ -346,16 +350,16 @@ class MPISimulator(Simulator.BaseSimulator):
         return self.sim.check()
 
     def compute_dt(self):
-        local_dt = np.array([np.float32(self.sim.compute_dt())]);
+        local_dt = np.array([np.float32(self.sim.compute_dt())])
         global_dt = np.empty(1, dtype=np.float32)
         self.grid.comm.Allreduce(local_dt, global_dt, op=MPI.MIN)
-        self.logger.debug("Local dt: {:f}, global dt: {:f}".format(local_dt[0], global_dt[0]))
+        self.logger.debug(f"Local dt: {local_dt[0]}, global dt: {global_dt[0]}")
         return global_dt[0]
 
     def get_extent(self):
         """
         Function which returns the extent of node with rank
-        rank in the grid
+        in the grid
         """
         width = self.sim.nx*self.sim.dx
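Aside (not part of the commit): compute_dt reduces each rank's stable timestep to the global minimum so that all nodes advance in lockstep. A minimal standalone mpi4py sketch of the same pattern:

    # Run with e.g. `mpirun -n 4 python demo.py` (illustrative).
    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD
    local_dt = np.array([0.1 * (comm.rank + 1)], dtype=np.float32)  # fake per-rank dt
    global_dt = np.empty(1, dtype=np.float32)
    comm.Allreduce(local_dt, global_dt, op=MPI.MIN)
    # Every rank now holds the same global_dt[0] == 0.1 (rank 0's value)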
@@ -385,7 +389,7 @@ class MPISimulator(Simulator.BaseSimulator):
self.profiling_data_mpi["end"]["t_mpi_halo_exchange_download"] += time.time()
#Send/receive to north/south neighbours
#Send/receive to north/south neighbors
self.profiling_data_mpi["start"]["t_mpi_halo_exchange_sendreceive"] += time.time()
comm_send = []