| 275 | |
| 276 | if op_type == 'cd': |
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Central-difference ("cd") 3x3 convolution.

    For every output position this evaluates ``sum_i w_i * (x_i - x_c)``,
    where ``x_i`` runs over the (zero-padded, possibly dilated) 3x3 window and
    ``x_c`` is the input value under the kernel centre. Expanding the sum gives
    ``conv(x, w) - x_c * sum_i w_i``. The second term is a 1x1 convolution
    with the spatially summed kernel, so no explicit patch extraction is needed.

    Args:
        x: Input tensor, passed straight to ``F.conv2d``.
        weights: ``F.conv2d``-style weight tensor whose two spatial dims must
            both be 3.
        bias: Optional bias, added once (only the main convolution uses it).
        stride: Convolution stride, shared by both convolutions.
        padding: Must equal ``dilation`` so the window stays centred on the
            same input pixel that the 1x1 centre term reads.
        dilation: 1 or 2.
        groups: Grouped-convolution factor, shared by both convolutions.

    Returns:
        ``conv2d(x, weights, bias, ...) - conv2d(x, weights.sum((2, 3)), ...)``.

    Raises:
        AssertionError: if ``dilation``, the kernel size or ``padding`` violate
            the constraints above. Raised explicitly rather than via ``assert``
            so the checks are not stripped under ``python -O``; the exception
            type is unchanged for existing callers.
    """
    if dilation not in (1, 2):
        raise AssertionError('dilation for cd_conv should be in 1 or 2')
    if weights.size(2) != 3 or weights.size(3) != 3:
        raise AssertionError('kernel size for cd_conv should be 3x3')
    if padding != dilation:
        raise AssertionError('padding for cd_conv set wrong')

    # Spatially summed kernel: a 1x1 conv with it yields x_c * sum(w) at each
    # output position; with the same stride its output grid matches ``y``.
    weights_c = weights.sum(dim=[2, 3], keepdim=True)
    yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
    # Ordinary convolution; bias is applied only here so it appears once.
    y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    return y - yc
| 286 | return func |
| 287 | elif op_type == 'ad': |
| 288 | def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): |