(ctx, *args)
| 119 | class PushState(torch.autograd.Function): |
@staticmethod
def forward(ctx, *args):
    """Return the inputs with every ``torch.Tensor`` replaced by a clone.

    The positional inputs are passed through ``tree_map`` (bound elsewhere
    in this file) together with a function that clones tensors and hands
    back anything else unchanged.

    Args:
        ctx: Autograd context object; not used in this method.
        *args: Arbitrary positional inputs.

    Returns:
        The sole mapped element when exactly one input was given,
        otherwise the whole mapped collection.
    """

    def _clone_if_tensor(leaf):
        # Only tensors are copied; every other value is returned as-is.
        if isinstance(leaf, torch.Tensor):
            return leaf.clone()
        return leaf

    cloned = tree_map(_clone_if_tensor, args)
    # A single input is returned bare instead of wrapped in a container.
    return cloned[0] if len(cloned) == 1 else cloned
| 126 | |
| 127 | @staticmethod |
| 128 | def backward(ctx, *grad_outs): |