(
noise_pred_pos: torch.Tensor,
noise_pred_neg: torch.Tensor,
cfg_scale: float,
eta: float = 1.0,
norm_threshold: float = 0.0,
)
| 53 | |
| 54 | |
def apg(
    noise_pred_pos: torch.Tensor,
    noise_pred_neg: torch.Tensor,
    cfg_scale: float,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
) -> torch.Tensor:
    """Combine positive/negative noise predictions using projected guidance.

    The structure appears to follow Adaptive Projected Guidance (APG,
    Sadat et al. 2024). NOTE(review): confirm against the ``project`` helper,
    which is defined elsewhere and not visible here.

    Args:
        noise_pred_pos: Noise prediction for the positive branch. Must have
            at least 3 dims, because the norm below is taken over the last
            three. NOTE(review): presumably a latent such as (B, C, F, H, W)
            or (B, C, H, W). Confirm the layout with callers.
        noise_pred_neg: Noise prediction for the negative branch. Must be
            subtractable from ``noise_pred_pos``.
        cfg_scale: Guidance scale. The update is multiplied by
            ``cfg_scale - 1``.
        eta: Weight on the component that ``project`` returns as parallel to
            ``noise_pred_pos``. The orthogonal component is always kept at
            weight 1.
        norm_threshold: If > 0, the guidance difference is scaled down so its
            L2 norm over the last three dims is at most this value. It is
            never scaled up. 0 disables the rescaling.

    Returns:
        ``noise_pred_pos + (cfg_scale - 1) * (orthogonal + eta * parallel)``,
        where the components come from ``project(diff, noise_pred_pos)``.
    """
    diff = noise_pred_pos - noise_pred_neg

    if norm_threshold > 0:
        # Compute the ratio on the small keepdim-shaped norm tensor and cap it
        # at 1. Multiplying broadcasts it over diff, so no full-size ones
        # tensor is needed. A zero norm gives inf, which clamps to 1.0 and
        # leaves diff unchanged.
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
        scale_factor = (norm_threshold / diff_norm).clamp(max=1.0)
        diff = diff * scale_factor

    diff_parallel, diff_orthogonal = project(diff, noise_pred_pos)
    normalized_update = diff_orthogonal + eta * diff_parallel
    return noise_pred_pos + (cfg_scale - 1) * normalized_update
| 72 | |
| 73 | |
| 74 | @comfy_node(name="LTXVApplySTG") |
no test coverage detected