input (FloatTensor): ``(n, num_classes)``. target (LongTensor): ``(n,)``, the indices of the target classes
(ctx, input, target)
@staticmethod
@custom_fwd
def forward(ctx, input, target):
    """Compute a per-row loss from class scores and target indices.

    For each row the returned value is
    ``sum_j(input_j**2 - tau**2) / 2 - input[target] + 0.5``, where the sum
    runs over the entries with ``input_j > tau`` and ``tau`` is the
    threshold returned by ``_threshold_and_support(input, dim=1)``.

    Args:
        input (FloatTensor): ``(n, num_classes)``.
        target (LongTensor): ``(n,)``, the indices of the target classes.

    Returns:
        Tensor: ``(n,)`` loss values, clamped to be non-negative.

    Raises:
        ValueError: if ``input`` is not 2-dimensional.
    """
    # Explicit rank check; the gather/sum along dim=1 below assume a
    # (n, num_classes) layout.
    if input.dim() != 2:
        raise ValueError(
            "input must be 2-dimensional (n, num_classes), got shape "
            f"{tuple(input.shape)}"
        )

    # Score of the target class for every row. Squeeze only the gathered
    # column so a batch of size 1 keeps its batch dimension.
    z_k = input.gather(1, target.unsqueeze(1)).squeeze(1)

    # Only the threshold is needed here; the support size is discarded.
    tau_z, _ = _threshold_and_support(input, dim=1)
    support = input > tau_z

    # Zero fill value with the same dtype and device as ``input`` so that
    # ``torch.where`` never mixes dtypes (e.g. half/double inputs).
    zero = input.new_zeros(())
    x = torch.where(support, input**2 - tau_z**2, zero).sum(dim=1)

    ctx.save_for_backward(input, target, tau_z)

    # clamping necessary because of numerical errors: loss should be lower
    # bounded by zero, but negative values near zero are possible without
    # the clamp
    return torch.clamp(x / 2 - z_k + 0.5, min=0.0)
| 29 | |
| 30 | @staticmethod |
| 31 | @custom_bwd |
nothing calls this directly
no test coverage detected