| 48 | return pyr |
| 49 | |
class LapLoss(torch.nn.Module):
    """Laplacian-pyramid L1 loss.

    Builds a Laplacian pyramid of both ``input`` and ``target`` via
    ``laplacian_pyramid`` and returns the sum, over corresponding pyramid
    levels, of ``torch.nn.functional.l1_loss`` (default ``reduction='mean'``,
    i.e. mean absolute error per level).

    Args:
        max_levels: number of pyramid levels, forwarded unchanged to
            ``laplacian_pyramid``.
        channels: forwarded to ``gauss_kernel``; presumably the number of
            image channels the blur kernel is replicated for (TODO confirm).
    """

    def __init__(self, max_levels=5, channels=3):
        super().__init__()
        self.max_levels = max_levels
        # Kept as a plain attribute (not a registered buffer) so the module's
        # state_dict stays exactly as before. Consequence: Module.to()/.cuda()/
        # .half() do NOT convert it, which forward() compensates for.
        self.gauss_kernel = gauss_kernel(channels=channels)

    def forward(self, input, target):
        """Return the summed per-level L1 distance between the two pyramids.

        ``input`` and ``target`` are passed as-is to ``laplacian_pyramid``;
        presumably they are image batches of identical shape whose channel
        count matches ``channels`` (assumption — confirm against callers).
        """
        kernel = self.gauss_kernel
        # The kernel was created once in __init__ and is not moved by
        # Module.to(); align it with the input's device and dtype so the
        # pyramid filtering works for CPU/CUDA and fp16/fp64 inputs alike.
        # .to() returns the same tensor when nothing needs converting.
        if torch.is_tensor(kernel):
            kernel = kernel.to(device=input.device, dtype=input.dtype)
        pyr_input = laplacian_pyramid(
            img=input, kernel=kernel, max_levels=self.max_levels
        )
        pyr_target = laplacian_pyramid(
            img=target, kernel=kernel, max_levels=self.max_levels
        )
        return sum(
            torch.nn.functional.l1_loss(a, b)
            for a, b in zip(pyr_input, pyr_target)
        )