MCP · Copy · Index your code
hub / github.com/hzwer/ECCV2022-RIFE / SOBEL

Class SOBEL

model/loss.py:58–81  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

56
57
class SOBEL(nn.Module):
    """Per-pixel Sobel edge loss between a prediction and its target.

    Every channel of every sample is filtered with the horizontal (X) and
    vertical (Y) 3x3 Sobel kernels (zero padding of 1, so H and W are kept).
    The loss map is the absolute X-response difference plus the absolute
    Y-response difference. No reduction is applied; callers reduce as needed.
    """

    def __init__(self):
        super(SOBEL, self).__init__()
        kernel = torch.tensor([
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1],
        ]).float()
        # conv2d weights of shape (out=1, in=1, 3, 3). Kept as plain attributes
        # (not buffers) so state_dict() is unaffected. Their device/dtype is
        # matched to the inputs lazily in forward(), so the module works
        # wherever the inputs live.
        self.kernelX = kernel.unsqueeze(0).unsqueeze(0)
        self.kernelY = kernel.t().contiguous().unsqueeze(0).unsqueeze(0)

    def _weight_like(self, ref):
        """Return the X and Y kernels stacked as a (2, 1, 3, 3) weight on ref's device/dtype.

        Converted kernels are cached back onto the module, so a device copy
        only happens when the input placement or dtype changes.
        """
        if self.kernelX.device != ref.device or self.kernelX.dtype != ref.dtype:
            self.kernelX = self.kernelX.to(device=ref.device, dtype=ref.dtype)
            self.kernelY = self.kernelY.to(device=ref.device, dtype=ref.dtype)
        return torch.cat([self.kernelX, self.kernelY], 0)

    def forward(self, pred, gt):
        """Compute the unreduced Sobel L1 loss map.

        Args:
            pred: 4-D tensor of shape (N, C, H, W).
            gt: tensor with exactly the same shape as ``pred``.

        Returns:
            Tensor of shape (N * C, 1, H, W) holding
            ``|Sx(pred) - Sx(gt)| + |Sy(pred) - Sy(gt)|`` per pixel.

        Raises:
            ValueError: if ``pred`` and ``gt`` shapes differ.
        """
        if pred.shape != gt.shape:
            raise ValueError(
                f"SOBEL: pred and gt must have the same shape, "
                f"got {tuple(pred.shape)} and {tuple(gt.shape)}")
        N, C, H, W = pred.shape
        # Each channel is filtered independently as its own 1-channel image.
        # Convolution is linear, so filtering (pred - gt) once yields the same
        # response differences as filtering both images and subtracting.
        diff = (pred - gt).reshape(N * C, 1, H, W)
        grads = F.conv2d(diff, self._weight_like(diff), padding=1)
        # Output channel 0 is the X response, channel 1 the Y response.
        return grads.abs().sum(dim=1, keepdim=True)
82
83class MeanShift(nn.Conv2d):
84 def __init__(self, data_mean, data_std, data_range=1, norm=True):

Callers 3

__init__ · Method · 0.85
__init__ · Method · 0.85
__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected