()
| 229 | |
| 230 | |
def main() -> None:
    """Run a CuTe DSL source file and write its primary compiled output.

    Command-line arguments this wrapper does not recognise are passed on to
    the executed source file through ``sys.argv``.

    The following environment variables are written before the source file
    runs: ``CUTE_DSL_ARCH``, ``CUTE_DSL_DUMP_DIR`` and ``CUTE_DSL_KEEP`` are
    always overwritten, while ``CUTE_DSL_KEEP_IR``, ``CUTE_DSL_KEEP_PTX``,
    ``CUTE_DSL_KEEP_CUBIN``, ``CUTE_DSL_NO_CACHE``,
    ``CUTE_DSL_DISABLE_FILE_CACHING`` and ``CUDA_TOOLKIT_PATH`` are only given
    a default when not already set.

    Raises:
        SystemExit: If ``cutlass.cute`` cannot be imported, or if the executed
            source file exits with a non-zero status.
    """
    parser = argparse.ArgumentParser(description="CuTe DSL Compiler Explorer wrapper")
    parser.add_argument("input_file", type=Path, help="Path to the input Python file")
    parser.add_argument("--output_file", type=Path, required=True, help="Path to the output PTX file")
    parser.add_argument("--arch", default="sm_90a", help="Value for CUTE_DSL_ARCH")
    parser.add_argument("--keep", default=None, help="Value for CUTE_DSL_KEEP")
    parser.add_argument("--python_path", default="", help="Extra PYTHONPATH entries separated by os.pathsep")
    parser.add_argument("--no_runtime_patch", action="store_true", help="Allow compiled kernels to run")
    # Unknown arguments are not an error: they belong to the source file.
    args, source_args = parser.parse_known_args()

    output_file = args.output_file
    # The directory holding the output file is also used as the dump directory.
    dump_dir = output_file.parent
    dump_dir.mkdir(parents=True, exist_ok=True)
    # Drop any output left by a previous run before producing a new one.
    output_file.unlink(missing_ok=True)

    add_python_paths(args.python_path)
    add_python_paths(os.environ.get("CUTE_DSL_PYTHONPATH", ""))

    os.environ["CUTE_DSL_ARCH"] = args.arch
    os.environ["CUTE_DSL_DUMP_DIR"] = str(dump_dir)
    # Precedence: --keep, then an existing CUTE_DSL_KEEP, then the built-in default.
    os.environ["CUTE_DSL_KEEP"] = args.keep or os.environ.get("CUTE_DSL_KEEP", "ir,ptx,cubin")
    os.environ.setdefault("CUTE_DSL_KEEP_IR", "1")
    os.environ.setdefault("CUTE_DSL_KEEP_PTX", "1")
    os.environ.setdefault("CUTE_DSL_KEEP_CUBIN", "1")
    os.environ.setdefault("CUTE_DSL_NO_CACHE", "1")
    os.environ.setdefault("CUTE_DSL_DISABLE_FILE_CACHING", "1")
    os.environ.setdefault("CUDA_TOOLKIT_PATH", "/usr/local")

    # Imported only now, after the extra import paths and the environment
    # variables above have been applied.
    try:
        import cutlass.cute as cute
    except ModuleNotFoundError as exc:
        raise SystemExit(
            "CuTe DSL is not importable. Use a Python with cutlass.cute installed, "
            f"or pass --python_path /path/to/cutlass/python/CuTeDSL. Missing module: {exc.name}."
        ) from exc

    patch_compile_defaults(cute, args.arch, dump_dir)

    # NOTE(review): --no_runtime_patch is described as "Allow compiled kernels
    # to run", so this patch presumably keeps kernels from launching; confirm
    # against patch_runtime_calls.
    if not args.no_runtime_patch:
        patch_runtime_calls()

    old_argv = sys.argv
    try:
        # Make the source file see itself as the program being run.
        sys.argv = [str(args.input_file), *source_args]
        runpy.run_path(str(args.input_file), run_name="__main__")
    except SystemExit as exc:
        # A script ending in sys.exit() or sys.exit(0) completed successfully;
        # carry on so its output is still written. Any other exit status is a
        # real failure and propagates unchanged.
        exited_cleanly = exc.code is None or (isinstance(exc.code, int) and exc.code == 0)
        if not exited_cleanly:
            raise
    finally:
        sys.argv = old_argv

    write_primary_output(output_file, dump_dir)
| 280 | |
| 281 | |
| 282 | if __name__ == "__main__": |
no test coverage detected