Dynamically create a Triton function or kernel from a src string, linking any symbols in the kernel to objects specified by extra_globals.
(src, module, attrs=None, **extra_globals)
| 31 | |
| 32 | |
def define_kernel(src, module, attrs=None, **extra_globals):
    """
    Dynamically create a Triton function or kernel from a src string,
    linking any symbols in the kernel to objects specified by extra_globals.

    Args:
        src: Python source text containing the function definition. It is
            dedented, and everything before the first line that starts with
            ``def <name>(`` (e.g. decorators) is discarded. ``src`` is run
            with ``exec``, so only pass trusted source.
        module: Module whose ``__name__`` is recorded as the generated
            function's ``__module__``.
        attrs: Optional dict of keyword arguments forwarded to
            ``triton.JITFunction``.
        **extra_globals: Names made visible to the kernel body. They are
            layered on top of the globals of the module defining this helper.

    Returns:
        The ``triton.JITFunction`` wrapping the generated function.

    Raises:
        ValueError: If ``src`` contains no line starting with ``def <name>(``.
    """
    import re

    # Template function: its code object is reused to build a real Python
    # function bound to a custom globals dict. The body is never compiled by
    # Triton; the actual source is installed below via _unsafe_update_src.
    def _empty_fn():
        pass

    gdict = dict(_empty_fn.__globals__)
    gdict.update(extra_globals)
    f = types.FunctionType(_empty_fn.__code__, gdict)
    f.__module__ = module.__name__

    src = textwrap.dedent(src)
    # Anchor the header match to the start of a line (the same pattern Triton
    # uses to locate a kernel's header) so "def " inside a comment or string
    # is not mistaken for the definition.
    header = re.search(r"^def\s+(\w+)\s*\(", src, re.MULTILINE)
    if header is None:
        raise ValueError(
            "define_kernel: src does not contain a top-level function definition"
        )
    src = src[header.start():]
    function_name = header.group(1)

    # Execute in a scratch copy of the globals so the kernel's own globals are
    # neither polluted with the temporary Python function nor allowed to have
    # extra_globals entries overwritten by it.
    namespace = dict(gdict)
    exec(src, namespace)
    python_fn = namespace[function_name]

    # Expose the real signature, name and docstring on the template function
    # before wrapping it.
    f.__signature__ = inspect.signature(python_fn)
    f.__name__ = function_name
    f.__doc__ = python_fn.__doc__

    f = triton.JITFunction(f, **(attrs if attrs is not None else {}))
    f._unsafe_update_src(src)
    return f
| 67 | |
| 68 | |
| 69 | def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()): |
no test coverage detected