refactor(GPUSimulator): follow PEP 8 style guide

This commit is contained in:
Anthony Berg
2025-02-14 12:40:31 +01:00
parent ce8e834771
commit ef207432db
17 changed files with 286 additions and 354 deletions

View File

@@ -19,8 +19,6 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import os
import numpy as np
@@ -38,11 +36,10 @@ import pycuda.driver as cuda
from GPUSimulators import Autotuner, Common
"""
Class which keeps track of the CUDA context and some helper functions
"""
class CudaContext(object):
"""
Class which keeps track of the CUDA context and some helper functions
"""
def __init__(self, device=None, context_flags=None, use_cache=True, autotuning=True):
"""
@@ -50,6 +47,7 @@ class CudaContext(object):
Set device to an id or pci_bus_id to select a specific GPU
Set context_flags to cuda.ctx_flags.SCHED_BLOCKING_SYNC for a blocking context
"""
self.use_cache = use_cache
self.logger = logging.getLogger(__name__)
self.modules = {}
@@ -94,7 +92,6 @@ class CudaContext(object):
if (autotuning):
self.logger.info("Autotuning enabled. It may take several minutes to run the code the first time: have patience")
self.autotuner = Autotuner.Autotuner()
def __del__(self, *args):
self.logger.info("Cleaning up CUDA context handle <%s>", str(self.cuda_context.handle))
@@ -119,10 +116,8 @@ class CudaContext(object):
self.logger.debug("<%s> Detaching", str(self.cuda_context.handle))
self.cuda_context.detach()
def __str__(self):
return "CudaContext id " + str(self.cuda_context.handle)
def hash_kernel(kernel_filename, include_dirs):
# Generate a kernel ID for our caches
@@ -171,18 +166,19 @@ class CudaContext(object):
return kernel_hasher.hexdigest()
"""
Reads a text file and creates an OpenCL kernel from that
"""
def get_module(self, kernel_filename,
include_dirs=[], \
defines={}, \
compile_args={'no_extern_c', True}, jit_compile_args={}):
"""
Helper function to print compilation output
Reads a text file and creates an OpenCL kernel from that
"""
def cuda_compile_message_handler(compile_success_bool, info_str, error_str):
"""
Helper function to print compilation output
"""
self.logger.debug("Compilation returned %s", str(compile_success_bool))
if info_str:
self.logger.debug("Info: %s", info_str)
@@ -257,16 +253,18 @@ class CudaContext(object):
self.modules[kernel_hash] = module
return module
"""
Clears the kernel cache (useful for debugging & development)
"""
def clear_kernel_cache(self):
"""
Clears the kernel cache (useful for debugging & development)
"""
self.logger.debug("Clearing cache")
self.modules = {}
gc.collect()
"""
Synchronizes all streams etc
"""
def synchronize(self):
"""
Synchronizes all streams etc
"""
self.cuda_context.synchronize()