fix(kernel): type for defines in get_module

This commit is contained in:
Anthony Berg 2025-06-24 17:45:50 +02:00
parent ae23145753
commit d2544e7c55
2 changed files with 2 additions and 2 deletions

View File

@ -182,7 +182,7 @@ class CudaContext(object):
def get_module(self, kernel_filename: str,
include_dirs: dict=None,
defines:list[str]=None,
defines:dict[str: int]=None,
compile_args:dict=None, jit_compile_args:dict=None) -> cuda.Module:
"""
Reads a text file and creates an OpenCL kernel from that.

View File

@ -58,7 +58,7 @@ class HIPContext(Context):
def get_module(self, kernel_filename: str,
include_dirs: dict=None,
defines:list[str]=None,
defines:dict[str: int]=None,
compile_args:dict=None,
jit_compile_args:dict=None):
"""