Added example script

This commit is contained in:
André R. Brodtkorb
2018-12-04 18:06:45 +01:00
parent 4292513c03
commit 12174b39db
3 changed files with 203 additions and 13 deletions

View File

@@ -44,8 +44,12 @@ Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
def __init__(self, blocking=False, use_cache=True, autotuning=True):
self.blocking = blocking
def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True):
"""
Create a new CUDA context
Set device to an id or pci_bus_id to select a specific GPU
Set context_flags to cuda.ctx_flags.SCHED_BLOCKING_SYNC for a blocking context
"""
self.use_cache = use_cache
self.logger = logging.getLogger(__name__)
self.modules = {}
@@ -60,17 +64,19 @@ class CudaContext(object):
#Print some info about CUDA
self.logger.info("CUDA version %s", str(cuda.get_version()))
self.logger.info("Driver version %s", str(cuda.get_driver_version()))
self.cuda_device = cuda.Device(0)
self.logger.info("Using '%s' GPU", self.cuda_device.name())
if device is None:
device = 0
self.cuda_device = cuda.Device(device)
self.logger.info("Using device %d/%d '%s' (%s) GPU", device, cuda.Device.count(), self.cuda_device.name(), self.cuda_device.pci_bus_id())
self.logger.debug(" => compute capability: %s", str(self.cuda_device.compute_capability()))
# Create the CUDA context
if (self.blocking):
self.cuda_context = self.cuda_device.make_context(flags=cuda.ctx_flags.SCHED_BLOCKING_SYNC)
self.logger.warning("Using blocking context")
else:
self.cuda_context = self.cuda_device.make_context(flags=cuda.ctx_flags.SCHED_AUTO)
if context_flags is None:
context_flags=cuda.ctx_flags.SCHED_AUTO
self.cuda_context = self.cuda_device.make_context(flags=context_flags)
free, total = cuda.mem_get_info()
self.logger.debug(" => memory: %d / %d MB available", int(free/(1024*1024)), int(total/(1024*1024)))