| 97 | |
@staticmethod
def backward(ctx, *output_grads):
    """Re-run the saved function under autograd and return input gradients.

    The inputs stored on ``ctx`` are detached and re-marked as requiring
    grad, ``ctx.run_function`` is executed again with gradients enabled,
    and ``torch.autograd.grad`` differentiates the recomputed outputs with
    respect to the inputs and the trainable parameters.

    Args:
        ctx: Context object carrying ``input_tensors``, ``input_params``,
            ``run_function`` and ``gpu_autocast_kwargs``.
        *output_grads: Gradients with respect to the outputs of
            ``ctx.run_function``.

    Returns:
        A tuple ``(None, None, *grads)`` with one entry per element of
        ``ctx.input_tensors`` followed by one per element of
        ``ctx.input_params``; entries for tensors that do not require grad
        are ``None``.
        NOTE(review): the two leading ``None`` values presumably line up
        with two non-tensor arguments of the matching ``forward`` -- confirm
        against its signature.
    """
    # Fresh leaves: the recomputed graph must be rooted at these tensors.
    ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
    with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
        # Fixes a bug where the first op in run_function modifies the
        # Tensor storage in place, which is not allowed for detach()'d
        # Tensors.
        shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
        output_tensors = ctx.run_function(*shallow_copies)

    params = list(ctx.input_params)
    # Only differentiate w.r.t. parameters that actually require grad.
    trainable_params = [p for p in params if p.requires_grad]
    grads = torch.autograd.grad(
        output_tensors,
        ctx.input_tensors + trainable_params,
        output_grads,
        allow_unused=True,
    )

    # Re-align the computed gradients with the full argument list, yielding
    # None for tensors that were excluded above. A single-pass iterator
    # avoids the O(n) cost of list.pop(0) on every step.
    grad_iter = iter(grads)
    input_grads = tuple(
        next(grad_iter) if tensor.requires_grad else None
        for tensor in ctx.input_tensors + params
    )

    # Drop references so the saved tensors and recomputed graph can be freed.
    del ctx.input_tensors
    del ctx.input_params
    del output_tensors
    return (None, None) + input_grads
| 126 | |
| 127 | |
| 128 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): |