MCPcopy
hub / github.com/hzwer/ECCV2022-RIFE / Ternary

Class Ternary

model/loss.py:20–55  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

18
19
class Ternary(nn.Module):
    """Soft ternary (census-style) per-pixel distance between two images.

    Both images are reduced to one gray channel. Each pixel is then compared
    with every pixel of its ``patch_size x patch_size`` neighbourhood through a
    smooth sign function. The two resulting descriptors are compared with a
    soft Hamming distance. ``forward`` returns a per-pixel map, not a reduced
    scalar.
    """

    def __init__(self, patch_size=7):
        """Build the one-hot neighbourhood-extraction kernel.

        Args:
            patch_size: Side length of the square comparison window. It must
                be a positive odd integer so the window is centred on the
                pixel. The default of 7 reproduces the original hard-coded
                setting.

        Raises:
            ValueError: If ``patch_size`` is not a positive odd integer.
        """
        super(Ternary, self).__init__()
        if patch_size < 1 or patch_size % 2 == 0:
            raise ValueError(
                "patch_size must be a positive odd integer, got {!r}".format(patch_size))
        out_channels = patch_size * patch_size
        # Kernel of shape (out_channels, 1, patch_size, patch_size). Output
        # channel k has a single 1 at window position
        # (k // patch_size, k % patch_size), so it copies that neighbour.
        # This is identical to
        # np.transpose(np.eye(out_channels).reshape((p, p, 1, out_channels)), (3, 2, 0, 1)).
        w = torch.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size)
        # `device` is a module-level global defined elsewhere in this file.
        self.w = w.float().to(device)

    def transform(self, img):
        """Return smooth sign differences between each pixel and its neighbours.

        Args:
            img: (N, 1, H, W) single-channel tensor, e.g. the output of
                ``rgb2gray``.

        Returns:
            (N, patch_size ** 2, H, W) tensor. Each value is
            ``d / sqrt(0.81 + d ** 2)``, where ``d`` is the difference between
            a neighbour and the centre pixel. Every value lies in (-1, 1).
        """
        # `self.w` is a plain attribute, so Module.to() / DataParallel do not
        # move or cast it. Match the input's device and dtype here;
        # Tensor.to(other) returns `self.w` unchanged when they already match.
        weight = self.w.to(img)
        # For an odd kernel size, zero padding of half the kernel keeps H and W
        # unchanged, so `patches - img` broadcasts.
        patches = F.conv2d(img, weight, padding=weight.shape[-1] // 2, bias=None)
        transf = patches - img
        transf_norm = transf / torch.sqrt(0.81 + transf ** 2)
        return transf_norm

    def rgb2gray(self, rgb):
        """Collapse channels 0, 1 and 2 of a 4-D tensor into one gray channel.

        Uses the Rec. 601 luma weights 0.2989 / 0.5870 / 0.1140 for
        channels 0 / 1 / 2. Any further channels are ignored.
        NOTE(review): this assumes the channel order is R, G, B; confirm
        against the callers.
        """
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def hamming(self, t1, t2):
        """Soft Hamming distance between two ``transform`` outputs.

        Each squared difference ``s`` is mapped to ``s / (0.1 + s)``, a value
        in [0, 1). The result is averaged over dim 1 with keepdim, giving
        shape (N, 1, H, W).
        """
        dist = (t1 - t2) ** 2
        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
        return dist_norm

    def valid_mask(self, t, padding):
        """Return an (N, 1, H, W) mask of ones with a zero border ``padding`` pixels wide.

        N, H and W are taken from ``t``. The mask has ``t``'s dtype and device.
        """
        n, _, h, w = t.size()
        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, img0, img1):
        """Compute the per-pixel ternary distance between ``img0`` and ``img1``.

        Args:
            img0, img1: 4-D tensors with at least three channels (see
                ``rgb2gray``).

        Returns:
            (N, 1, H, W) distance map whose outermost 1-pixel border is zeroed.
        """
        img0 = self.transform(self.rgb2gray(img0))
        img1 = self.transform(self.rgb2gray(img1))
        # NOTE(review): the convolution zero-pads patch_size // 2 pixels
        # (3 by default), yet only a 1-pixel border is masked out. This is
        # kept as-is so existing loss values do not change.
        return self.hamming(img0, img1) * self.valid_mask(img0, 1)
56
57
58class SOBEL(nn.Module):

Callers 3

loss.py · File · 0.85
__init__ · Method · 0.85
__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected