MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / define_kernel

Function define_kernel

triton_kernels/specialize.py:33–66  ·  view source on GitHub ↗

Dynamically create a Triton function or kernel from a src string, linking any symbols in the kernel to objects specified by extra_globals.

(src, module, attrs=None, **extra_globals)

Source from the content-addressed store, hash-verified

31
32
def define_kernel(src, module, attrs=None, **extra_globals):
    """
    Dynamically create a Triton function or kernel from a src string,
    linking any symbols in the kernel to objects specified by extra_globals.

    Args:
        src: Python source text containing a function definition. It is
            dedented, and everything before the first ``"def "`` (for
            example decorator lines) is discarded.
        module: Object whose ``__name__`` is recorded as the new function's
            ``__module__``.
        attrs: Optional mapping of keyword arguments forwarded to
            ``triton.JITFunction``. ``None`` means no extra arguments.
        **extra_globals: Names added to the new function's global namespace,
            overriding same-named globals of this module.

    Returns:
        The ``triton.JITFunction`` wrapping the defined function.

    Raises:
        ValueError: If ``src`` does not contain ``"def "``.
    """
    src = textwrap.dedent(src)
    def_start = src.find("def ")
    if def_start < 0:
        # Without this guard, src[-1:] would keep only the last character
        # and the failure would surface later as an unrelated-looking error.
        raise ValueError(
            "define_kernel: src does not contain a function definition ('def ')"
        )
    src = src[def_start:]
    function_name = src[len("def "):].split("(")[0].strip()

    # Create a template function: a fresh function object whose globals are
    # a private copy of this module's globals plus extra_globals.
    def _empty_fn():
        pass

    gdict = dict(_empty_fn.__globals__)
    gdict.update(extra_globals)
    f = types.FunctionType(_empty_fn.__code__, gdict)
    f.__module__ = module.__name__

    # Execute src directly in gdict (which is f.__globals__), so every name
    # src defines, including the function itself, lands in f's globals.
    exec(src, gdict)
    py_fn = gdict[function_name]

    # f still carries _empty_fn's code object; copy the real function's
    # signature, name and docstring onto it.
    f.__signature__ = inspect.signature(py_fn)
    f.__name__ = function_name
    f.__doc__ = py_fn.__doc__

    if attrs is None:
        attrs = {}
    f = triton.JITFunction(f, **attrs)
    # Associate the real kernel text with the JIT function; f's own code
    # object is only the empty placeholder.
    f._unsafe_update_src(src)
    return f
67
68
69def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):

Callers 1

specializeFunction · 0.85

Calls 2

updateMethod · 0.45
splitMethod · 0.45

Tested by

no test coverage detected