Compute tensors via a schedulable TIR PrimFunc Parameters ---------- input_tensors: list of Tensor Input tensors that map to the corresponding primfunc input params. primfunc: PrimFunc The TIR PrimFunc Returns ------- tensor: Tensor or list of Tensors
(input_tensors: list[_tensor.Tensor], primfunc: tvm.tirx.PrimFunc, **kwargs)
| 351 | |
| 352 | |
| 353 | def extern_primfunc(input_tensors: list[_tensor.Tensor], primfunc: tvm.tirx.PrimFunc, **kwargs): |
| 354 | """Compute tensors via a schedulable TIR PrimFunc |
| 355 | |
| 356 | Parameters |
| 357 | ---------- |
| 358 | input_tensors: list of Tensor |
| 359 | Input tensors that map to the corresponding primfunc input params. |
| 360 | |
| 361 | primfunc: PrimFunc |
| 362 | The TIR PrimFunc |
| 363 | |
| 364 | Returns |
| 365 | ------- |
| 366 | tensor: Tensor or list of Tensors |
| 367 | The created tensor or tuple of tensors if it contains multiple outputs. |
| 368 | |
| 369 | Example |
| 370 | ------- |
| 371 | In the code below, a TVMScript defined TIR PrimFunc is inlined into |
| 372 | a TE ExternOp. Applying te.create_prim_func on this ExternOp yields a TIR PrimFunc. |
| 373 | |
| 374 | .. code-block:: python |
| 375 | |
| 376 | A = te.placeholder((128, 128), name="A") |
| 377 | B = te.placeholder((128, 128), name="B") |
| 378 | |
| 379 | @T.prim_func(s_tir=True) |
| 380 | def before_split(a: T.handle, b: T.handle) -> None: |
| 381 | A = T.match_buffer(a, (128, 128)) |
| 382 | B = T.match_buffer(b, (128, 128)) |
| 383 | for i, j in T.grid(128, 128): |
| 384 | with T.sblock("B"): |
| 385 | vi, vj = T.axis.remap("SS", [i, j]) |
| 386 | B[vi, vj] = A[vi, vj] * 2.0 |
| 387 | |
| 388 | C = te.extern_primfunc([A, B], before_split) |
| 389 | """ |
| 390 | |
| 391 | # dt_access_map and primfunc.buffer_map are unordered, so use order from primfunc.params |
| 392 | dt_access_map = tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc) |
| 393 | ordered_buffers = [primfunc.buffer_map[param] for param in primfunc.params] |
| 394 | in_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][0])] |
| 395 | out_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][1])] |
| 396 | assert in_buffers, "PrimFunc has no input buffers" |
| 397 | assert out_buffers, "PrimFunc has no output buffers" |
| 398 | |
| 399 | outputs = [] |
| 400 | inplace = [] |
| 401 | input_buffers = in_buffers |
| 402 | for obuf in out_buffers: |
| 403 | if obuf in in_buffers: |
| 404 | inplace.append(obuf) |
| 405 | else: |
| 406 | outputs.append(obuf) |
| 407 | |
| 408 | if not outputs: |
| 409 | iobuf = inplace.pop() |
| 410 | input_buffers.remove(iobuf) |