| 31 | |
| 32 | |
def main() -> None:
    """Run a Helion example file and dump the Triton code of every kernel call it makes.

    Command-line arguments:
        --inputfile:  path to a Python file that is executed as module ``example``.
        --outputfile: path the captured Triton source is written to (UTF-8); each
                      captured kernel call contributes its code followed by a blank line.

    While the input file runs, two monkey-patches are active:
      * ``helion.kernel`` defaults ``autotune_effort='none'`` unless the caller passed
        ``config`` or ``autotune_effort`` explicitly.
      * ``Kernel.__call__`` runs the kernel normally, then binds the positional call
        arguments and records ``to_triton_code`` for the spec's default config.
    Both patches are undone before the output file is written, even if the
    example raises.

    On any exception, each of the exception's ``args`` (or its ``str`` when it has
    none) is written to stderr on its own line and the process exits with status 255.
    """
    import functools

    parser = argparse.ArgumentParser(description="Output Triton code from public Helion kernels.")
    parser.add_argument("--inputfile", required=True)
    parser.add_argument("--outputfile", required=True)
    args = parser.parse_args()

    try:
        import helion
        from helion.runtime.kernel import Kernel

        # One entry per intercepted call: (kernel, positional call args, Triton source).
        compiled_kernels: list[tuple[Kernel, object, str]] = []

        original_kernel = helion.kernel
        original_call = Kernel.__call__

        @functools.wraps(original_kernel)
        def patched_kernel(*kernel_args, **kernel_kwargs):
            # Skip autotuning unless the caller chose a config or an effort level.
            if "config" not in kernel_kwargs and "autotune_effort" not in kernel_kwargs:
                kernel_kwargs["autotune_effort"] = "none"
            return original_kernel(*kernel_args, **kernel_kwargs)

        def patched_call(self, *call_args, **call_kwargs):
            result = original_call(self, *call_args, **call_kwargs)
            # Capturing the code is best-effort: a failure must not break the
            # example run, but it is reported rather than silently dropped.
            try:
                bound = self.bind(call_args)
                cfg = bound.config_spec.default_config()
                triton_code = bound.to_triton_code(cfg)
            except Exception as capture_error:
                with contextlib.suppress(Exception):
                    sys.stderr.write(
                        f"warning: could not generate Triton code: {capture_error!r}\n"
                    )
            else:
                compiled_kernels.append((self, call_args, triton_code))
            return result

        helion.kernel = patched_kernel
        Kernel.__call__ = patched_call
        try:
            spec = importlib.util.spec_from_file_location("example", args.inputfile)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if spec is None or spec.loader is None:
                raise ImportError(f"cannot load a module from {args.inputfile!r}")
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)
        finally:
            # Always undo both patches, even if the example raised.
            Kernel.__call__ = original_call
            helion.kernel = original_kernel

        with open(args.outputfile, "w", encoding="utf-8") as out:
            for _kernel, _call_args, triton_code in compiled_kernels:
                out.write(triton_code)
                out.write("\n\n")

    except Exception as error:
        messages = list(getattr(error, "args", None) or [str(error)])
        with contextlib.suppress(Exception):
            sys.stderr.writelines([str(m) + "\n" for m in messages])
        sys.exit(255)
| 89 | |
| 90 | |