(ctx, *grads)
| 31 | |
| 32 | @staticmethod |
| 33 | def backward(ctx, *grads): |
| 34 | input, = ctx.saved_tensors |
| 35 | grad_out = torch.zeros_like(input) |
| 36 | grad_out[:] = grads[dist.get_rank()] |
| 37 | return grad_out |
| 38 | |
| 39 | |
| 40 | class CrossEntropyLoss(torch.nn.Module): |
no outgoing calls
no test coverage detected