(ctx, grad_output)
| 49 | |
@staticmethod
def backward(ctx, grad_output):
    """Return gradients w.r.t. ``input`` and ``weight`` of the custom linear op.

    NOTE(review): the matching ``forward`` is not visible here; presumably it
    computed ``input.mm(weight.t())`` and stored ``(input, weight)`` with
    ``ctx.save_for_backward`` -- TODO confirm.

    Args:
        ctx: Autograd context. ``ctx.saved_tensors`` must unpack to
            ``(input, weight)``. ``ctx.needs_input_grad`` flags which of the
            forward inputs require a gradient.
        grad_output: Gradient of the loss w.r.t. the forward output. It must
            be 2-D, because ``Tensor.mm`` is used.

    Returns:
        ``(grad_input, grad_weight)``. An entry is ``None`` when the
        corresponding forward input does not need a gradient.
    """
    saved_input, weight = ctx.saved_tensors
    grad_input = grad_weight = None
    # Only pay for the matmuls autograd actually needs. Returning None for
    # inputs that don't require a gradient is the documented contract.
    if ctx.needs_input_grad[0]:
        grad_input = grad_output.mm(weight)
    if ctx.needs_input_grad[1]:
        grad_weight = grad_output.t().mm(saved_input)
    return grad_input, grad_weight
| 59 | |
| 60 | class myLinear(nn.Module): |
| 61 | def __init__(self, input_features, output_features): |