| 56 | |
| 57 | |
class SOBEL(nn.Module):
    """Sobel gradient-difference loss between a prediction and a target.

    Every channel of ``pred`` and ``gt`` is filtered independently with the
    3x3 horizontal (``kernelX``) and vertical (``kernelY``) Sobel kernels,
    using zero padding of 1 so the spatial size is preserved. The loss is the
    per-pixel sum of absolute differences of the two gradient responses.

    No reduction is applied: the result has shape ``(N*C, 1, H, W)`` and the
    caller is expected to reduce it (e.g. ``.mean()``).
    """

    def __init__(self):
        super(SOBEL, self).__init__()
        kernel_x = torch.tensor([
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1],
        ]).float()
        # The vertical kernel is the transpose of the horizontal one.
        kernel_y = kernel_x.clone().T
        # Registered as buffers so ``module.to(...)`` / ``.cuda()`` move them
        # with the module. ``persistent=False`` keeps them out of
        # ``state_dict``, so existing checkpoints still load unchanged.
        self.register_buffer(
            "kernelX", kernel_x.unsqueeze(0).unsqueeze(0), persistent=False)
        self.register_buffer(
            "kernelY", kernel_y.unsqueeze(0).unsqueeze(0), persistent=False)

    def forward(self, pred, gt):
        """Return the per-pixel Sobel L1 difference between ``pred`` and ``gt``.

        Args:
            pred: 4-D tensor ``(N, C, H, W)``.
            gt: tensor with the same number of elements as ``pred``; it is
                reshaped with ``pred``'s dimensions.

        Returns:
            Tensor of shape ``(N*C, 1, H, W)``:
            ``|sobel_x(pred) - sobel_x(gt)| + |sobel_y(pred) - sobel_y(gt)|``.
        """
        N, C, H, W = pred.shape
        # Fold channels into the batch so a single-channel kernel filters
        # every channel independently; stack pred and gt to convolve once.
        img_stack = torch.cat(
            [pred.reshape(N * C, 1, H, W), gt.reshape(N * C, 1, H, W)], 0)

        # Match the input's device and dtype (no-op when already matching),
        # so the loss also works for float64/half inputs and for modules
        # that were never moved explicitly.
        kernel_x = self.kernelX.to(device=img_stack.device, dtype=img_stack.dtype)
        kernel_y = self.kernelY.to(device=img_stack.device, dtype=img_stack.dtype)

        sobel_stack_x = F.conv2d(img_stack, kernel_x, padding=1)
        sobel_stack_y = F.conv2d(img_stack, kernel_y, padding=1)
        pred_X, gt_X = sobel_stack_x[:N * C], sobel_stack_x[N * C:]
        pred_Y, gt_Y = sobel_stack_y[:N * C], sobel_stack_y[N * C:]

        L1X = torch.abs(pred_X - gt_X)
        L1Y = torch.abs(pred_Y - gt_Y)
        return L1X + L1Y
| 82 | |
| 83 | class MeanShift(nn.Conv2d): |
| 84 | def __init__(self, data_mean, data_std, data_range=1, norm=True): |