MCPcopy
hub / github.com/compiler-explorer/compiler-explorer / main

Function main

etc/scripts/cutedsl_wrapper.py:231–279  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

229
230
def _run_user_script(script: Path, script_args: list[str]) -> None:
    """Run *script* as ``__main__`` with ``sys.argv`` set as if it were invoked directly.

    ``sys.argv`` is always restored afterwards. A ``sys.exit()`` carrying a
    success status (``None`` or ``0``) is treated as normal completion so the
    caller can still collect the compiler output; any other exit status is
    re-raised unchanged.
    """
    saved_argv = sys.argv
    try:
        sys.argv = [str(script), *script_args]
        runpy.run_path(str(script), run_name="__main__")
    except SystemExit as exc:
        # Scripts commonly finish with ``sys.exit(main())``. Without this, a
        # successful script would abort the wrapper before write_primary_output()
        # runs and no output file would be produced.
        if exc.code not in (None, 0):
            raise
    finally:
        sys.argv = saved_argv


def main() -> None:
    """Compile a CuTe DSL Python script and write its primary output for Compiler Explorer.

    Parses the wrapper's own options; any unrecognised arguments are forwarded
    to the user script as its ``sys.argv[1:]``. Then it:

    1. prepares the output directory;
    2. sets the ``CUTE_DSL_*`` environment;
    3. imports ``cutlass.cute`` and applies the compile-default patches
       (plus the runtime patches unless ``--no_runtime_patch``);
    4. executes the script as ``__main__``;
    5. calls ``write_primary_output`` to produce ``--output_file``.

    Raises:
        SystemExit: if ``cutlass.cute`` cannot be imported, or if the user
            script exits with a non-success status.
    """
    parser = argparse.ArgumentParser(description="CuTe DSL Compiler Explorer wrapper")
    parser.add_argument("input_file", type=Path, help="Path to the input Python file")
    parser.add_argument("--output_file", type=Path, required=True, help="Path to the output PTX file")
    parser.add_argument("--arch", default="sm_90a", help="Value for CUTE_DSL_ARCH")
    parser.add_argument("--keep", default=None, help="Value for CUTE_DSL_KEEP")
    parser.add_argument("--python_path", default="", help="Extra PYTHONPATH entries separated by os.pathsep")
    parser.add_argument("--no_runtime_patch", action="store_true", help="Allow compiled kernels to run")
    # Unknown options are not errors here: they belong to the user script.
    args, source_args = parser.parse_known_args()

    output_file = args.output_file
    # The directory of the requested output file doubles as the DSL dump directory.
    dump_dir = output_file.parent
    dump_dir.mkdir(parents=True, exist_ok=True)
    # Remove any previous result so a stale file is never mistaken for this run's output.
    output_file.unlink(missing_ok=True)

    add_python_paths(args.python_path)
    add_python_paths(os.environ.get("CUTE_DSL_PYTHONPATH", ""))

    os.environ["CUTE_DSL_ARCH"] = args.arch
    os.environ["CUTE_DSL_DUMP_DIR"] = str(dump_dir)
    # An explicit --keep wins; otherwise honour the caller's CUTE_DSL_KEEP, else keep everything.
    os.environ["CUTE_DSL_KEEP"] = args.keep or os.environ.get("CUTE_DSL_KEEP", "ir,ptx,cubin")
    # setdefault: values already present in the caller's environment take precedence.
    os.environ.setdefault("CUTE_DSL_KEEP_IR", "1")
    os.environ.setdefault("CUTE_DSL_KEEP_PTX", "1")
    os.environ.setdefault("CUTE_DSL_KEEP_CUBIN", "1")
    os.environ.setdefault("CUTE_DSL_NO_CACHE", "1")
    os.environ.setdefault("CUTE_DSL_DISABLE_FILE_CACHING", "1")
    os.environ.setdefault("CUDA_TOOLKIT_PATH", "/usr/local")

    # Imported only after the environment and sys.path are configured above.
    try:
        import cutlass.cute as cute
    except ModuleNotFoundError as exc:
        raise SystemExit(
            "CuTe DSL is not importable. Use a Python with cutlass.cute installed, "
            f"or pass --python_path /path/to/cutlass/python/CuTeDSL. Missing module: {exc.name}."
        ) from exc

    patch_compile_defaults(cute, args.arch, dump_dir)

    # By default runtime calls are patched; per the flag's help text, --no_runtime_patch
    # is what allows compiled kernels to actually run.
    if not args.no_runtime_patch:
        patch_runtime_calls()

    _run_user_script(args.input_file, source_args)

    write_primary_output(output_file, dump_dir)
280
281
282if __name__ == "__main__":

Callers 1

cutedsl_wrapper.pyFile · 0.70

Calls 5

add_python_pathsFunction · 0.85
patch_compile_defaultsFunction · 0.85
patch_runtime_callsFunction · 0.85
write_primary_outputFunction · 0.85
getMethod · 0.65

Tested by

no test coverage detected