| 18 | |
| 19 | |
class Ternary(nn.Module):
    """Per-pixel distance between two RGB batches based on local patch differences.

    Each image is converted to grayscale, every pixel is compared with the
    pixels of its 7x7 neighbourhood, the differences are squashed into
    (-1, 1), and the two resulting descriptors are compared channel-wise.
    """

    def __init__(self):
        super(Ternary, self).__init__()
        patch_size = 7
        out_channels = patch_size * patch_size
        # An identity matrix reshaped into one-hot kernels of shape
        # (out_channels, 1, patch_size, patch_size): output channel k copies
        # the k-th pixel of every patch_size x patch_size neighbourhood.
        kernels = np.eye(out_channels).reshape(
            (patch_size, patch_size, 1, out_channels))
        kernels = np.transpose(kernels, (3, 2, 0, 1))
        # `device` is a module-level name defined outside this class.
        self.w = torch.tensor(kernels).float().to(device)

    def transform(self, img):
        """Return the normalised neighbourhood differences of ``img``.

        Args:
            img: tensor of shape (N, 1, H, W); the kernels have a single
                input channel.

        Returns:
            Tensor of shape (N, k*k, H, W), where k is the kernel size, with
            values in (-1, 1).
        """
        # `self.w` is a plain attribute, not a parameter or buffer, so it does
        # not follow `Module.to()` / `.half()`.  Align it with the input here;
        # `Tensor.to` returns the tensor itself when nothing has to change.
        weight = self.w.to(device=img.device, dtype=img.dtype)
        # "Same" padding, derived from the kernel size so the two stay coupled.
        pad = weight.shape[-1] // 2
        patches = F.conv2d(img, weight, padding=pad, bias=None)
        # `img` broadcasts over the channel axis: neighbour minus centre pixel.
        transf = patches - img
        return transf / torch.sqrt(0.81 + transf ** 2)

    def rgb2gray(self, rgb):
        """Collapse channels 0, 1, 2 of a 4-D tensor into one weighted channel."""
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        return 0.2989 * r + 0.5870 * g + 0.1140 * b

    def hamming(self, t1, t2):
        """Return the mean over dim 1 of ``d / (0.1 + d)`` with ``d = (t1 - t2)**2``.

        The channel dimension is kept with size 1.
        """
        dist = (t1 - t2) ** 2
        return torch.mean(dist / (0.1 + dist), dim=1, keepdim=True)

    def valid_mask(self, t, padding):
        """Return an (N, 1, H, W) mask that is 0 on a ``padding``-wide border, 1 inside.

        The mask has the dtype and device of ``t``.
        """
        n, _, h, w = t.size()
        # Allocate directly with t's dtype/device instead of converting after.
        inner = t.new_ones(n, 1, h - 2 * padding, w - 2 * padding)
        return F.pad(inner, [padding] * 4)

    def forward(self, img0, img1):
        """Return the masked (N, 1, H, W) distance map between two RGB batches."""
        img0 = self.transform(self.rgb2gray(img0))
        img1 = self.transform(self.rgb2gray(img1))
        # NOTE(review): only a 1-pixel border is masked although the 7x7
        # patches rely on 3 pixels of zero padding -- confirm this is intended.
        return self.hamming(img0, img1) * self.valid_mask(img0, 1)
| 56 | |
| 57 | |
| 58 | class SOBEL(nn.Module): |