(ctx, grad_output)
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
    """Backpropagate ``grad_output`` through the sparsemax mapping.

    On the support of the forward output (entries where ``output != 0``),
    the incoming gradient is centred by subtracting its average over the
    support. Entries outside the support receive a zero gradient.

    Args:
        ctx: autograd context. Reads ``ctx.saved_tensors`` as
            ``(supp_size, output)`` and the reduction axis ``ctx.dim``.
            ``supp_size`` is presumably the per-slice support size saved by
            ``forward``. It may either keep ``dim`` as a size-1 axis or have
            it reduced away; both layouts are accepted.
        grad_output: gradient of the loss w.r.t. the forward output
            (same shape as ``output``).

    Returns:
        ``(grad_input, None)``: the gradient for the first forward input,
        and ``None`` for the second forward argument (presumably the
        non-differentiable ``dim``; confirm against ``forward``).
    """
    supp_size, output = ctx.saved_tensors
    dim = ctx.dim

    grad_input = grad_output.clone()
    # Coordinates outside the support have zero derivative.
    grad_input[output == 0] = 0

    supp_size = supp_size.to(output.dtype)
    # Align supp_size with output along `dim` only. A bare `.squeeze()`
    # would also drop unrelated size-1 axes (e.g. a singleton batch or
    # channel axis) and then mis-broadcast against the per-slice sums.
    if supp_size.dim() < output.dim():
        supp_size = supp_size.unsqueeze(dim)

    # Mean of the (masked) gradient over the support of each slice.
    v_hat = grad_input.sum(dim=dim, keepdim=True) / supp_size
    grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
    return grad_input, None


# Functional entry point: calling ``sparsemax(...)`` passes its arguments to
# ``SparsemaxFunction.forward`` via autograd's ``Function.apply``, so the
# custom ``backward`` above is used when gradients flow through the result.
sparsemax = SparsemaxFunction.apply