| 35 | |
class CUDAProgram:
  """One kernel loaded from a compiled module image on a single CUDA device.

  Construction makes `dev.context` current, loads `lib` with cuModuleLoadData,
  and resolves the function `name` from the resulting module. The module is
  unloaded with cuModuleUnload when the object is finalized.
  """
  def __init__(self, dev:CUDADevice, name:str, lib:bytes, smem:int=0, **kwargs):
    """Load `lib` as a CUDA module on `dev` and look up kernel `name`.

    Args:
      dev: device whose `context` is made current before the module is loaded.
      name: kernel symbol to resolve in the loaded module (encoded as UTF-8).
      lib: module image handed to cuModuleLoadData; at DEBUG >= 5 it is also
        decoded as UTF-8 and pretty-printed via `pretty_ptx`.
      smem: dynamic shared-memory size in bytes; when > 0 it is applied as the
        kernel's CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES.
      **kwargs: accepted and ignored here (presumably to keep a uniform program
        constructor signature across backends — TODO confirm).

    Raises:
      RuntimeError: if cuModuleLoadData returns a non-zero status.
      Failures of the other driver calls are surfaced through `check`.
    """
    self.dev, self.name, self.lib, self.smem = dev, name, lib, smem
    if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))

    check(cuda.cuCtxSetCurrent(self.dev.context))
    # Load into a local handle and only publish it as `self.module` once the load
    # succeeded, so a failed load never leaves a handle for `__del__` to unload.
    module = cuda.CUmodule()
    status = cuda.cuModuleLoadData(ctypes.byref(module), lib)
    if status != 0:
      raise RuntimeError(f"module load failed with status code {status}: {cuda.enum_cudaError_enum.get(status)}")
    self.module = module
    check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
    self.prg = prg
    if self.smem > 0: check(cuda.cuFuncSetAttribute(self.prg, cuda.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, self.smem))

  @suppress_finalizing
  def __del__(self):
    """Unload the module loaded by `__init__`, if one was loaded."""
    # Python runs __del__ even when __init__ raised; if that happened before the
    # module was loaded (context switch, DEBUG decode, or a failed load) there is
    # no `module` attribute and nothing to unload, so skip instead of raising
    # AttributeError from the finalizer.
    if hasattr(self, "module"): check(cuda.cuModuleUnload(self.module))