mirror of
https://github.com/smyalygames/FiniteVolumeGPU.git
synced 2026-01-14 15:48:43 +01:00
Added example script
This commit is contained in:
@@ -44,8 +44,12 @@ Class which keeps track of the CUDA context and some helper functions
|
||||
"""
|
||||
class CudaContext(object):
|
||||
|
||||
def __init__(self, blocking=False, use_cache=True, autotuning=True):
|
||||
self.blocking = blocking
|
||||
def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True):
|
||||
"""
|
||||
Create a new CUDA context
|
||||
Set device to an id or pci_bus_id to select a specific GPU
|
||||
Set context_flags to cuda.ctx_flags.SCHED_BLOCKING_SYNC for a blocking context
|
||||
"""
|
||||
self.use_cache = use_cache
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.modules = {}
|
||||
@@ -60,17 +64,19 @@ class CudaContext(object):
|
||||
#Print some info about CUDA
|
||||
self.logger.info("CUDA version %s", str(cuda.get_version()))
|
||||
self.logger.info("Driver version %s", str(cuda.get_driver_version()))
|
||||
|
||||
self.cuda_device = cuda.Device(0)
|
||||
self.logger.info("Using '%s' GPU", self.cuda_device.name())
|
||||
|
||||
if device is None:
|
||||
device = 0
|
||||
|
||||
self.cuda_device = cuda.Device(device)
|
||||
self.logger.info("Using device %d/%d '%s' (%s) GPU", device, cuda.Device.count(), self.cuda_device.name(), self.cuda_device.pci_bus_id())
|
||||
self.logger.debug(" => compute capability: %s", str(self.cuda_device.compute_capability()))
|
||||
|
||||
# Create the CUDA context
|
||||
if (self.blocking):
|
||||
self.cuda_context = self.cuda_device.make_context(flags=cuda.ctx_flags.SCHED_BLOCKING_SYNC)
|
||||
self.logger.warning("Using blocking context")
|
||||
else:
|
||||
self.cuda_context = self.cuda_device.make_context(flags=cuda.ctx_flags.SCHED_AUTO)
|
||||
if context_flags is None:
|
||||
context_flags=cuda.ctx_flags.SCHED_AUTO
|
||||
|
||||
self.cuda_context = self.cuda_device.make_context(flags=context_flags)
|
||||
|
||||
free, total = cuda.mem_get_info()
|
||||
self.logger.debug(" => memory: %d / %d MB available", int(free/(1024*1024)), int(total/(1024*1024)))
|
||||
|
||||
Reference in New Issue
Block a user