| 146 | |
| 147 | |
def get_relax_conv2d_module(
    data_shape,
    weight_shape,
    dtype,
    with_bias=False,
    activation=None,
    residual_bin_op=None,
    residual_activation=None,
):
    """Build a Relax IRModule whose "main" function is a conv2d with an optional epilogue.

    The function takes ``data`` and ``weight`` (and ``bias`` when requested) as
    arguments, in that order, and emits a single dataflow block containing the
    convolution followed by whichever optional steps were enabled.

    Parameters
    ----------
    data_shape
        Shape given to the ``data`` tensor argument.
    weight_shape
        Shape given to the ``weight`` tensor argument. Its first entry is also
        used as the last dimension of the bias tensor.
    dtype
        Dtype used for every tensor argument and as the conv2d ``out_dtype``.
    with_bias
        When truthy, a ``bias`` argument of shape ``(1, 1, 1, weight_shape[0])``
        is declared and added to the convolution result.
    activation
        Optional callable applied to the (possibly biased) convolution result.
    residual_bin_op
        Optional two-argument callable invoked as ``residual_bin_op(output, data)``.
    residual_activation
        Optional callable applied after the preceding steps. It is checked
        independently of ``residual_bin_op``.

    Returns
    -------
    tvm.IRModule
        A module mapping ``"main"`` to the function that was built.
    """
    with IRBuilder() as ir_builder, relax_builder.function():
        R.func_name("main")

        # Argument declaration order defines the parameter order of "main".
        data_arg = R.arg("data", R.Tensor(data_shape, dtype))
        weight_arg = R.arg("weight", R.Tensor(weight_shape, dtype))
        if with_bias:
            bias_arg = R.arg("bias", R.Tensor((1, 1, 1, weight_shape[0]), dtype))

        # Collect the enabled post-convolution steps in application order.
        # Each step maps the current result to the next expression to emit.
        epilogue = []
        if with_bias:
            epilogue.append(lambda expr: expr + bias_arg)
        if activation is not None:
            epilogue.append(activation)
        if residual_bin_op is not None:
            epilogue.append(lambda expr: residual_bin_op(expr, data_arg))
        if residual_activation is not None:
            epilogue.append(residual_activation)

        with R.dataflow() as dataflow_frame:
            conv_call = R.nn.conv2d(
                data_arg,
                weight_arg,
                out_dtype=dtype,
                padding=(1, 1),
                data_layout="NHWC",
                kernel_layout="OHWI",
            )
            result = R.emit(conv_call)
            for step in epilogue:
                result = R.emit(step(result))
            R.output(result)

        # The function returns the dataflow block's first output variable.
        R.func_ret_value(dataflow_frame.output_vars[0])

    # Retrieve the built function only after the builder context has closed.
    built_func = ir_builder.get()
    return tvm.IRModule({"main": built_func})
| 190 | |
| 191 | |
| 192 | def _to_concrete_shape(symbolic_shape, var_table=None): |