MCPcopy Index your code
hub / github.com/apache/tvm / get_relax_conv2d_module

Function get_relax_conv2d_module

tests/python/relax/test_codegen_cudnn.py:56–104  ·  view source on GitHub ↗
(
    data_shape,
    weight_shape,
    dtype,
    with_bias=False,
    activation=None,
    residual_bin_op=None,
    residual_activation=None,
    data_layout="NHWC",
    kernel_layout="OHWI",
)

Source from the content-addressed store, hash-verified

54
55
def get_relax_conv2d_module(
    data_shape,
    weight_shape,
    dtype,
    with_bias=False,
    activation=None,
    residual_bin_op=None,
    residual_activation=None,
    data_layout="NHWC",
    kernel_layout="OHWI",
):
    """Build an IRModule whose "main" function is a conv2d plus optional epilogue.

    The function takes ``data`` and ``weight`` arguments (and ``bias`` when
    ``with_bias`` is truthy), applies ``R.nn.conv2d`` with ``padding=(1, 1)``
    and then, in this order, each epilogue step that is enabled:

    1. bias addition (``output + bias``),
    2. ``activation(output)``,
    3. ``residual_bin_op(output, data)``,
    4. ``residual_activation(output)``.

    Parameters
    ----------
    data_shape
        Shape passed to ``R.Tensor`` for the ``data`` argument.
    weight_shape
        Shape passed to ``R.Tensor`` for the ``weight`` argument. Its first
        entry is used as the channel extent of the bias tensor.
    dtype
        Dtype used for every argument and as the conv2d ``out_dtype``.
    with_bias
        When truthy, register a ``bias`` argument and add it to the conv2d result.
    activation
        Optional callable applied to the running output.
    residual_bin_op
        Optional callable invoked as ``residual_bin_op(output, data)``.
    residual_activation
        Optional callable applied after the residual step.
    data_layout
        Forwarded to conv2d. With bias enabled it also selects the bias shape:
        ``(1, 1, 1, C)`` for "NHWC" and ``(1, C, 1, 1)`` for "NCHW".
    kernel_layout
        Forwarded to conv2d.

    Returns
    -------
    The result of ``tvm.IRModule({"main": func})`` for the built function.

    Raises
    ------
    ValueError
        If ``with_bias`` is truthy and ``data_layout`` is neither "NHWC" nor "NCHW".
    """
    with IRBuilder() as ir_builder:
        with relax_builder.function():
            R.func_name("main")
            data = R.arg("data", R.Tensor(data_shape, dtype))
            weight = R.arg("weight", R.Tensor(weight_shape, dtype))

            bias = None
            if with_bias:
                # Reject unknown layouts before touching weight_shape.
                if data_layout not in ("NHWC", "NCHW"):
                    raise ValueError(f"Unsupported data_layout: {data_layout}")
                channels = weight_shape[0]
                if data_layout == "NHWC":
                    bias_shape = (1, 1, 1, channels)
                else:
                    bias_shape = (1, channels, 1, 1)
                bias = R.arg("bias", R.Tensor(bias_shape, dtype))

            # Collect the enabled epilogue steps in application order; each one
            # maps the running output to the expression to emit next.
            epilogue = []
            if with_bias:
                epilogue.append(lambda out: out + bias)
            if activation is not None:
                epilogue.append(activation)
            if residual_bin_op is not None:
                epilogue.append(lambda out: residual_bin_op(out, data))
            if residual_activation is not None:
                epilogue.append(residual_activation)

            with R.dataflow() as dataflow_frame:
                conv_call = R.nn.conv2d(
                    data,
                    weight,
                    out_dtype=dtype,
                    padding=(1, 1),
                    data_layout=data_layout,
                    kernel_layout=kernel_layout,
                )
                result = R.emit(conv_call)
                for step in epilogue:
                    result = R.emit(step(result))
                R.output(result)

            R.func_ret_value(dataflow_frame.output_vars[0])

    main_func = ir_builder.get()
    return tvm.IRModule({"main": main_func})
105
106
107def get_result_with_relax_cudnn_offload(mod, np_inputs, cuda_graph=False):

Calls 8

IRBuilderClass · 0.90
TensorMethod · 0.80
dataflowMethod · 0.80
outputMethod · 0.80
functionMethod · 0.45
emitMethod · 0.45
conv2dMethod · 0.45
getMethod · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…