refactor(simulator): change from CudaContext to KernelContext

This commit is contained in:
Anthony Berg 2025-06-24 20:37:24 +02:00
parent 3be1b074dd
commit bfed972046
5 changed files with 8 additions and 8 deletions

View File

@ -27,7 +27,7 @@ from IPython.core.magic import line_magic, Magics, magics_class
import pycuda.driver as cuda
from GPUSimulators.common import IPEngine
from GPUSimulators.gpu import CudaContext
from GPUSimulators.gpu import KernelContext
@magics_class
@ -59,7 +59,7 @@ class MagicCudaContext(Magics):
self.logger.debug("Creating context")
use_cache = False if args.no_cache else True
use_autotuning = False if args.no_autotuning else True
self.shell.user_ns[args.name] = CudaContext(context_flags=context_flags, use_cache=use_cache,
self.shell.user_ns[args.name] = KernelContext(context_flags=context_flags, use_cache=use_cache,
autotuning=use_autotuning)
# this function will be called on exceptions in any cell

View File

@ -20,7 +20,6 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import logging
from GPUSimulators import Simulator
import numpy as np
from mpi4py import MPI
import time
@ -28,6 +27,8 @@ import time
import pycuda.driver as cuda
#import nvtx
from GPUSimulators import Simulator
def get_grid(num_nodes, num_dims):
if not isinstance(num_nodes, int):

View File

@ -28,7 +28,7 @@ from enum import IntEnum
import pycuda.driver as cuda
from GPUSimulators.common import ProgressPrinter
from GPUSimulators.gpu import CudaContext
from GPUSimulators.gpu import KernelContext
def get_types(bc):
@ -107,7 +107,7 @@ class BoundaryCondition(object):
class BaseSimulator(object):
def __init__(self,
context: CudaContext,
context: KernelContext,
nx: int, ny: int,
dx: int, dy: int,
boundary_conditions: BoundaryCondition,

View File

@ -24,8 +24,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import numpy as np
from pycuda import gpuarray
from GPUSimulators.common import ArakawaA2D
from GPUSimulators import Simulator
from GPUSimulators.common import ArakawaA2D
from GPUSimulators.Simulator import BoundaryCondition

View File

@ -26,7 +26,6 @@ from pycuda import gpuarray
from GPUSimulators import Simulator
from GPUSimulators.common import ArakawaA2D
from GPUSimulators.gpu import CudaContext
from GPUSimulators.Simulator import BoundaryCondition
@ -36,7 +35,7 @@ class LxF(Simulator.BaseSimulator):
"""
def __init__(self,
context: CudaContext,
context,
h0: float, hu0: float, hv0: float,
nx: int, ny: int,
dx: int, dy: int,