Computes the loss weighting for SD3 training from the noise levels `sigmas`, using the scheme named by `weighting_scheme`. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
(weighting_scheme: str, sigmas=None)
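Written out, the three branches of the function compute the following weights for a noise level σ. Despite its name, `"sigma_sqrt"` returns σ⁻², not a square root. The bound on the cosmap weight is a short hand derivation added here (not taken from the source), and it assumes σ lies in [0, 1] as in SD3's flow-matching setup:

```latex
w_{\text{sigma\_sqrt}}(\sigma) = \sigma^{-2}, \qquad
w_{\text{cosmap}}(\sigma) = \frac{2}{\pi\,(1 - 2\sigma + 2\sigma^{2})}, \qquad
w_{\text{other}}(\sigma) = 1.

% Denominator of the cosmap weight, rewritten as a sum of squares:
1 - 2\sigma + 2\sigma^{2} = \sigma^{2} + (1 - \sigma)^{2} \in \left[\tfrac{1}{2},\, 1\right]
\quad \text{for } \sigma \in [0, 1]
\;\;\Rightarrow\;\;
\frac{2}{\pi} \le w_{\text{cosmap}}(\sigma) \le \frac{4}{\pi},
\quad \text{with the maximum } \tfrac{4}{\pi} \text{ at } \sigma = \tfrac{1}{2}.
```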
```python
import math

import torch


def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """
    Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        # Despite the name, this is inverse-square weighting: sigma^-2, cast to float32.
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        # CosMap weighting: 2 / (pi * (1 - 2*sigma + 2*sigma^2)).
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        # Any other scheme name falls back to uniform weights (a tensor of ones).
        weighting = torch.ones_like(sigmas)
    return weighting
```
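One way to see what each branch returns is to evaluate the function on a few noise levels. The sketch below copies the function body so it runs standalone, then feeds the cosmap weights into a weighted mean-squared-error loss. The `sigmas` values, the `(batch, 1, 1, 1)` shape, the `"none"` scheme string, and the random `model_pred`/`target` tensors are all illustrative assumptions, not taken from the diffusers training scripts.

```python
import math

import torch


def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    # Local copy of the function above so this sketch runs on its own.
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting


# Hypothetical per-sample noise levels, shaped (batch, 1, 1, 1) so they
# broadcast against image-shaped tensors of shape (batch, C, H, W).
sigmas = torch.tensor([0.25, 0.5, 1.0]).view(-1, 1, 1, 1)

w_sigma_sqrt = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)  # 16, 4, 1
w_cosmap = compute_loss_weighting_for_sd3("cosmap", sigmas)  # largest at sigma = 0.5
w_uniform = compute_loss_weighting_for_sd3("none", sigmas)  # all ones (fallback branch)

# Hypothetical weighted MSE: random tensors stand in for the model output and target.
torch.manual_seed(0)
model_pred = torch.randn(3, 4, 8, 8)
target = torch.randn(3, 4, 8, 8)

# Weight each element by its sample's weight, average per sample, then over the batch.
per_sample = (w_cosmap * (model_pred - target) ** 2).reshape(3, -1).mean(dim=1)
loss = per_sample.mean()
```

Because unrecognized scheme names silently fall back to uniform weights, a typo in `weighting_scheme` would not raise an error. Validating the string before calling the function is one possible safeguard.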