(v0: torch.Tensor, v1: torch.Tensor)
| 44 | |
| 45 | |
def project(v0: torch.Tensor, v1: torch.Tensor, *, dims=(-1, -2, -3)):
    """Split ``v0`` into components parallel and orthogonal to ``v1``.

    ``v1`` is L2-normalized over ``dims``. The projection of ``v0`` onto it is
    computed independently for every index of the remaining (non-reduced) axes.

    Args:
        v0: Tensor to decompose.
        v1: Direction tensor; must be broadcast-compatible with ``v0``.
        dims: Axes treated as one flattened vector for the dot product and the
            normalization. Defaults to the last three axes, matching the
            original hard-coded behaviour.

    Returns:
        ``(v0_parallel, v0_orthogonal)``, both cast back to ``v0``'s dtype,
        with ``v0_parallel + v0_orthogonal == v0`` up to rounding.
    """
    dtype = v0.dtype
    reduce_dims = list(dims)
    # Compute in high precision to limit cancellation error. The MPS backend
    # cannot represent float64 (``.double()`` raises there), so use float32
    # on that device instead of crashing.
    work_dtype = torch.float32 if v0.device.type == "mps" else torch.float64
    v0, v1 = v0.to(work_dtype), v1.to(work_dtype)
    # normalize clamps the norm at eps, so an all-zero v1 yields a zero
    # parallel component rather than NaNs.
    v1 = torch.nn.functional.normalize(v1, dim=reduce_dims)
    v0_parallel = (v0 * v1).sum(dim=reduce_dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
| 53 | |
| 54 | |
| 55 | def apg( |