MCPcopy Index your code
hub / github.com/apache/tvm / conv2d

Method conv2d

tests/python/relax/test_transform_fuse_ops.py:1043–1059  ·  view source on GitHub ↗
(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(320), T.int64(320), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"))

Source from the content-addressed store, hash-verified

1041
# TVMScript prim_func: 3x3, stride-1 2-D convolution over a copy of the input
# that carries a one-element zero border on its last two axes.
#   rxplaceholder   -- input,   shape (2, 320, 64, 64), float32
#   rxplaceholder_1 -- weights, shape (320, 320, 3, 3), float32
#   conv2d_nchw     -- output,  shape (2, 320, 64, 64), float32 (written in place)
1042 @T.prim_func(private=True, s_tir=True)
1043 def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(320), T.int64(320), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")):
# NOTE(review): op_pattern 4 presumably corresponds to TVM's kOutEWiseFusable
# pattern kind -- confirm against the OpPatternKind enum.
1044 T.func_attr({"op_pattern": 4, "tirx.noalias": True})
# Scratch buffer for the bordered input: the last two extents are 66 = 64 + 2.
1045 pad_temp = T.sblock_alloc_buffer((T.int64(2), T.int64(320), T.int64(66), T.int64(66)))
# Stage 1: fill pad_temp element by element.
1046 for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(320), T.int64(66), T.int64(66)):
1047 with T.sblock("pad_temp"):
1048 v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
# Declared read region is shifted by -1 on the last two axes to undo the border offset.
1049 T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
1050 T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
# Interior positions (1 <= index < 65 on both trailing axes) copy the input
# value; every other position is the zero border.
1051 pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(65) and T.int64(1) <= v_i3 and v_i3 < T.int64(65), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0))
# Stage 2: accumulate the convolution. The grid iterates four output axes
# (nn, ff, yy, xx) followed by three accumulation axes (rc, ry, rx).
1052 for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(2), T.int64(320), T.int64(64), T.int64(64), T.int64(320), T.int64(3), T.int64(3)):
1053 with T.sblock("conv2d_nchw"):
# "SSSSRRR": the first four axes are bound as 'S' and the last three as 'R'
# (spatial / reduction axes in TVMScript's axis-remap notation).
1054 v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
1055 T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], rxplaceholder_1[v_ff, v_rc, v_ry, v_rx])
1056 T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
# Initialise the output element to zero before accumulation.
1057 with T.init():
1058 conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
# Multiply-accumulate: the window position is the output coordinate plus the
# kernel offset (v_yy + v_ry, v_xx + v_rx), read from the bordered copy.
1059 conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * rxplaceholder_1[v_ff, v_rc, v_ry, v_rx]
1060
1061 @T.prim_func(private=True, s_tir=True)
1062 def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(1280)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1280), T.int64(320)), "float32"), matmul: T.Buffer((T.int64(2), T.int64(320)), "float32")):

Callers 15

_conv2d_impl Method · 0.45
_convolution Method · 0.45
conv2d Function · 0.45
forward Method · 0.45
convert_conv Method · 0.45
conv2d Function · 0.45
main Method · 0.45
forward Method · 0.45
main Method · 0.45
main Method · 0.45
main Method · 0.45

Calls 2

remap Method · 0.80
init Method · 0.45

Tested by

no test coverage detected