feat(mpi): calculating the grid and assigning them correct variables

This commit is contained in:
Anthony Berg 2025-08-08 00:01:04 +02:00
parent 8da34a01f8
commit 90a5ff81a8
2 changed files with 16 additions and 11 deletions

View File

@ -158,8 +158,9 @@ def run_simulation(simulator, simulator_args, outfile, save_times, save_var_name
ncvars['time'][:] = save_times ncvars['time'][:] = save_times
ncvars['time'].units = "s" ncvars['time'].units = "s"
x0, x1, y0, y1 = sim.get_extent() x0, x1, y0, y1 = sim.get_extent()
ncvars['x'][grid_x0:grid_x1] = np.linspace(x0, x1, simulator_args['nx'])
ncvars['y'][grid_y0:grid_y1] = np.linspace(y0, y1, simulator_args['ny']) ncvars['x'][grid_x0:grid_x1] = np.linspace(grid_x0, grid_x1-1, simulator_args['nx'])
ncvars['y'][grid_y0:grid_y1] = np.linspace(grid_y0, grid_y1-1, simulator_args['ny'])
# Choose which variables to download (prune None from the list, but keep the index) # Choose which variables to download (prune None from the list, but keep the index)
download_vars = [] download_vars = []
@ -208,7 +209,7 @@ def run_simulation(simulator, simulator_args, outfile, save_times, save_var_name
# Save to file # Save to file
for i, var_name in enumerate(save_var_names): for i, var_name in enumerate(save_var_names):
ncvars[var_name][save_step, grid_y0:grid_y1] = save_vars[i] ncvars[var_name][save_step, grid_y0:grid_y1, grid_x0:grid_x1] = save_vars[i]
profiling_data_sim_runner["end"]["t_nc_write"] += time.time() profiling_data_sim_runner["end"]["t_nc_write"] += time.time()

View File

@ -88,18 +88,22 @@ class MPIGrid(object):
if comm.size < 1: if comm.size < 1:
raise ValueError("Must have at least one node") raise ValueError("Must have at least one node")
grid = get_grid(comm.size, ndims) grid_x, grid_y = get_grid(comm.size, ndims)
self.x = grid[0] self.x = grid_x
self.y = grid[1] self.y = grid_y
self.x0 = nx * (self.x-1)
self.x1 = self.x0 + nx
self.y0 = ny * (self.y-1)
self.y1 = self.y0 + ny
self.comm = comm self.comm = comm
x, y = self.get_coordinate()
self.x0 = nx * x
self.x1 = self.x0 + nx
self.y0 = ny * y
self.y1 = self.y0 + ny
self.logger.debug( self.logger.debug(
f"Created MPI grid: {grid}. Rank {self.comm.rank} has coordinate {self.get_coordinate()}") f"Created MPI grid: ({grid_x}, {grid_y}). Rank {self.comm.rank} has coordinate: ({x}, {y})")
def get_coordinate(self, rank=None): def get_coordinate(self, rank=None):
if rank is None: if rank is None: