github.com/apache/tvm

Function extern_primfunc

python/tvm/te/operation.py:353–434

Compute tensors via a schedulable TIR PrimFunc.

Parameters:
  input_tensors (list of Tensor): Input tensors that map to the corresponding primfunc input params.
  primfunc (PrimFunc): The TIR PrimFunc.
Returns:
  tensor (Tensor or list of Tensors)

(input_tensors: list[_tensor.Tensor], primfunc: tvm.tirx.PrimFunc, **kwargs)


```python
def extern_primfunc(input_tensors: list[_tensor.Tensor], primfunc: tvm.tirx.PrimFunc, **kwargs):
    """Compute tensors via a schedulable TIR PrimFunc

    Parameters
    ----------
    input_tensors: list of Tensor
        Input tensors that map, in order, to the corresponding primfunc input params.

    primfunc: PrimFunc
        The TIR PrimFunc to inline into a TE ExternOp.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor, or a list of tensors if the PrimFunc has multiple outputs.

    Example
    -------
    In the code below, a TVMScript-defined TIR PrimFunc is inlined into
    a TE ExternOp. te.create_prim_func can then be applied to the resulting
    tensor to build a new PrimFunc.

    .. code-block:: python

        A = te.placeholder((128, 128), name="A")

        @T.prim_func(s_tir=True)
        def before_split(a: T.handle, b: T.handle) -> None:
            A = T.match_buffer(a, (128, 128))
            B = T.match_buffer(b, (128, 128))
            for i, j in T.grid(128, 128):
                with T.sblock("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] * 2.0

        # Only buffer A is read by the PrimFunc, so A is the sole input tensor
        C = te.extern_primfunc([A], before_split)
    """

    # dt_access_map and primfunc.buffer_map are unordered, so use order from primfunc.params
    dt_access_map = tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc)
    ordered_buffers = [primfunc.buffer_map[param] for param in primfunc.params]
    # Entry [0] holds a buffer's read regions and [1] its write regions
    in_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][0])]
    out_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][1])]
    assert in_buffers, "PrimFunc has no input buffers"
    assert out_buffers, "PrimFunc has no output buffers"

    outputs = []
    inplace = []
    input_buffers = in_buffers
    # Written buffers that are also read are in-place; the rest are pure outputs
    for obuf in out_buffers:
        if obuf in in_buffers:
            inplace.append(obuf)
        else:
            outputs.append(obuf)

    # If every written buffer is also read, pop one in-place buffer and drop it
    # from the inputs. input_buffers aliases in_buffers, so remove() mutates both.
    if not outputs:
        iobuf = inplace.pop()
        input_buffers.remove(iobuf)
    # ... (excerpt ends here; the full function spans lines 353 to 434)
```
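The buffer classification in this excerpt can be modeled without TVM. The sketch below is a pure-Python approximation, assuming buffers are plain strings and each access-map entry is a `(reads, writes)` pair standing in for what `DomainTouchedAccessMap` returns; `classify_buffers` and `promoted` are hypothetical names, not TVM API. It reproduces only the lines shown, not the truncated remainder of the function.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical stand-in for one access-map entry: (read regions, write regions).
Access = Tuple[Sequence[str], Sequence[str]]


def classify_buffers(
    params: List[str],
    buffer_map: Dict[str, str],
    access_map: Dict[str, Access],
) -> Tuple[List[str], List[str], List[str], Optional[str]]:
    """Model of extern_primfunc's buffer classification (shown lines only).

    Returns (input_buffers, outputs, inplace, promoted), where `promoted` is the
    in-place buffer popped when there are no pure outputs, else None.
    """
    # Order follows params, not the insertion order of the maps.
    ordered = [buffer_map[p] for p in params]
    in_buffers = [b for b in ordered if len(access_map[b][0])]   # has reads
    out_buffers = [b for b in ordered if len(access_map[b][1])]  # has writes
    assert in_buffers, "PrimFunc has no input buffers"
    assert out_buffers, "PrimFunc has no output buffers"

    outputs: List[str] = []
    inplace: List[str] = []
    input_buffers = in_buffers  # same list object, as in the excerpt
    for obuf in out_buffers:
        if obuf in in_buffers:
            inplace.append(obuf)  # read and written: in-place
        else:
            outputs.append(obuf)  # written only: pure output

    promoted: Optional[str] = None
    if not outputs:
        promoted = inplace.pop()
        input_buffers.remove(promoted)
    return input_buffers, outputs, inplace, promoted


if __name__ == "__main__":
    # Access pattern of the docstring example: A is read, B is written.
    # The map lists B first to show that params, not the map, fix the order.
    print(classify_buffers(["a", "b"], {"a": "A", "b": "B"},
                           {"B": ([], ["w"]), "A": (["r"], [])}))
    # prints (['A'], ['B'], [], None)
```

For the docstring example, A is the only input buffer and B the only output. A buffer that is both read and written lands in `inplace`, and if nothing is written-only, one such buffer is popped out of the input list.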

Callers

nothing calls this directly

Calls (4)

extern (function) · 0.70
append (method) · 0.45
pop (method) · 0.45
remove (method) · 0.45

Tested by

no test coverage detected
