fix(mpi): add synchronization for streams

Author: Anthony Berg, 2025-07-03 13:14:30 +02:00
parent fa3fcb76f8
commit 4dde38c2e5
6 changed files with 21 additions and 8 deletions

View File

@@ -6,7 +6,7 @@ from .simulator import BaseMPISimulator
 class CudaMPISimulator(BaseMPISimulator):
     def __init__(self, sim, grid):
-        super().__init__(sim, grid)
+        super().__init__(sim, grid, self.__create_pagelocked_memory)
 
     def __create_pagelocked_memory(self):
         self.in_e = cuda.pagelocked_empty((int(self.nvars), int(self.read_e[3]), int(self.read_e[2])),
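For context on the hunk above: `cuda.pagelocked_empty` allocates pinned (page-locked) host memory, which is what allows the asynchronous halo downloads later in this commit to overlap with GPU work. A minimal PyCUDA sketch of the idea, assuming a working CUDA device; the buffer shape and names are illustrative, not taken from the project:

```python
import numpy as np
import pycuda.autoinit  # noqa: F401 -- creates a default CUDA context
import pycuda.driver as cuda

stream = cuda.Stream()

# Pinned host buffer: memcpy_dtoh_async only overlaps with other work
# when the host side is page-locked.
host_buf = cuda.pagelocked_empty((4, 128, 128), dtype=np.float32)
dev_buf = cuda.mem_alloc(host_buf.nbytes)

# Enqueue an asynchronous device-to-host copy on the stream ...
cuda.memcpy_dtoh_async(host_buf, dev_buf, stream)
# ... and block only when the data is actually needed (e.g. before an MPI send).
stream.synchronize()
```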

View File

@@ -5,7 +5,7 @@ from .simulator import BaseMPISimulator
 class HIPMPISimulator(BaseMPISimulator):
     def __init__(self, sim, grid):
-        super().__init__(sim, grid)
+        super().__init__(sim, grid, self.__create_pagelocked_memory)
 
     def __create_pagelocked_memory(self):
         self.in_e = np.empty((int(self.nvars), int(self.read_e[3]), int(self.read_e[2])),

View File

@@ -20,6 +20,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
 """
 import logging
+from typing import Callable
 import numpy as np
 from mpi4py import MPI
 import time
@@ -32,7 +34,7 @@ class BaseMPISimulator(BaseSimulator):
     Class which handles communication between simulators on different MPI nodes
     """
-    def __init__(self, sim, grid):
+    def __init__(self, sim, grid, data_func: Callable):
         self.profiling_data_mpi = {'start': {}, 'end': {}}
         self.profiling_data_mpi["start"]["t_mpi_halo_exchange"] = 0
         self.profiling_data_mpi["end"]["t_mpi_halo_exchange"] = 0
@@ -124,7 +126,9 @@ class BaseMPISimulator(BaseSimulator):
         self.out_w = None
         self.out_n = None
         self.out_s = None
-        self.__create_pagelocked_memory()
+        # Creates the page locked memory
+        data_func()
 
         self.logger.debug(f"Simulator rank {self.grid.comm.rank} initialized on {MPI.Get_processor_name()}")
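The effect of the constructor change: each backend now hands its buffer-allocation routine to the base class as a callable, so `BaseMPISimulator` no longer needs to know how (or whether) page-locked memory is created for a given backend. A stripped-down sketch of the pattern; the class names and the print statement are only illustrative:

```python
from typing import Callable


class Base:
    def __init__(self, data_func: Callable[[], None]):
        # Backend-specific callable allocates the halo-exchange buffers
        # (page-locked for CUDA, plain numpy arrays on the HIP path in this diff).
        data_func()


class CudaLike(Base):
    def __init__(self):
        # The bound private method is passed in; name mangling resolves it here,
        # inside the subclass, so the base class never refers to it by name.
        super().__init__(self.__create_pagelocked_memory)

    def __create_pagelocked_memory(self):
        print("allocating pinned host buffers")


CudaLike()  # prints "allocating pinned host buffers"
```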
@@ -154,8 +158,8 @@
         self.full_exchange()
         # nvtx.mark("sync start", color="blue")
-        self.sim.stream.synchronize()
-        self.sim.internal_stream.synchronize()
+        self.sim.synchronize()
+        self.sim.internal_synchronize()
         # nvtx.mark("sync end", color="blue")
 
         self.profiling_data_mpi["n_time_steps"] += 1
@@ -205,7 +209,7 @@
         if self.south is not None:
             for k in range(self.nvars):
                 self.sim.u0[k].download(self.sim.stream, cpu_data=self.out_s[k, :, :], asynch=True, extent=self.read_s)
-        self.sim.stream.synchronize()
+        self.sim.synchronize()
 
         self.profiling_data_mpi["end"]["t_mpi_halo_exchange_download"] += time.time()
@@ -260,7 +264,7 @@
         if self.west is not None:
             for k in range(self.nvars):
                 self.sim.u0[k].download(self.sim.stream, cpu_data=self.out_w[k, :, :], asynch=True, extent=self.read_w)
-        self.sim.stream.synchronize()
+        self.sim.synchronize()
 
         self.profiling_data_mpi["end"]["t_mpi_halo_exchange_download"] += time.time()
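The reasoning behind the `synchronize()` call in the two download hunks: `download(..., asynch=True)` only enqueues device-to-host copies on the stream, and the MPI exchange reads the `out_*` host buffers immediately afterwards, so the stream has to be drained first. Calling `self.sim.synchronize()` instead of `self.sim.stream.synchronize()` keeps that barrier but stops the MPI layer from assuming a PyCUDA-style `stream` attribute. Below is a toy, runnable illustration of why the barrier must sit between the enqueue and the read; `FakeBackendSim` and its methods are made up for the example:

```python
import numpy as np


class FakeBackendSim:
    """Stand-in for CudaSimulator/HIPSimulator: only the interface matters."""

    def __init__(self):
        self.pending = []

    def download_async(self, k, into):
        # Pretend to enqueue a device-to-host copy; nothing is written yet.
        self.pending.append((k, into))

    def synchronize(self):
        # Drain the "stream": complete every enqueued copy.
        for k, buf in self.pending:
            buf.fill(k)
        self.pending.clear()


sim = FakeBackendSim()
out_s = [np.empty((2, 8)) for _ in range(3)]
for k, buf in enumerate(out_s):
    sim.download_async(k, into=buf)

sim.synchronize()  # the host buffers are only valid after this point
assert all((buf == k).all() for k, buf in enumerate(out_s))
```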

View File

@@ -22,3 +22,6 @@ class CudaSimulator(simulator.BaseSimulator):
     def synchronize(self):
         self.stream.synchronize()
+
+    def internal_synchronize(self):
+        self.internal_stream.synchronize()

View File

@@ -28,3 +28,6 @@ class HIPSimulator(simulator.BaseSimulator):
     def synchronize(self):
         hip_check(hip.hipStreamSynchronize(self.stream))
+
+    def internal_synchronize(self):
+        hip_check(hip.hipStreamSynchronize(self.internal_stream))
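`hip_check` itself is not part of this diff; hip-python calls return `hipError_t` codes (sometimes bundled in a tuple with output values) instead of raising. A typical shape for such a helper, stated as an assumption rather than the project's actual implementation:

```python
from hip import hip


def hip_check(call_result):
    """Unpack a hip-python return value and raise on any HIP error."""
    err, results = call_result, None
    if isinstance(call_result, tuple):
        err, *results = call_result
    if err != hip.hipError_t.hipSuccess:
        # Raise with the raw error code to avoid depending on string helpers.
        raise RuntimeError(f"HIP call failed: {err}")
    return results
```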

View File

@@ -143,6 +143,9 @@ class BaseSimulator(object):
     def synchronize(self):
         raise NotImplementedError("Needs to be implemented in HIP/CUDA subclass")
 
+    def internal_synchronize(self):
+        raise NotImplementedError("Needs to be implemented in HIP/CUDA subclass")
+
     def sim_time(self):
         return self.t
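With `synchronize()` and `internal_synchronize()` declared on `BaseSimulator`, the MPI code can stay backend-agnostic: it calls the base interface and each backend decides what synchronizing its streams means. A minimal sketch of that shape with stripped-down classes (the real constructors take arguments not shown here):

```python
class BaseSimulator:
    def synchronize(self):
        raise NotImplementedError("Needs to be implemented in HIP/CUDA subclass")

    def internal_synchronize(self):
        raise NotImplementedError("Needs to be implemented in HIP/CUDA subclass")


class ToyCudaSimulator(BaseSimulator):
    def synchronize(self):
        print("cudaStreamSynchronize(stream)")

    def internal_synchronize(self):
        print("cudaStreamSynchronize(internal_stream)")


def step(sim: BaseSimulator):
    # The caller never needs to know which backend it is driving.
    sim.synchronize()
    sim.internal_synchronize()


step(ToyCudaSimulator())
```

An `abc.ABC` with `@abstractmethod` would catch a missing override at instantiation time instead of at call time, but raising `NotImplementedError` as above keeps the existing class hierarchy unchanged.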