
Function compute_loss_weighting_for_sd3

src/diffusers/training_utils.py, lines 387–402 (github.com/huggingface/diffusers)

Computes the loss weighting for SD3 training from the noise levels sigmas, using the scheme named by weighting_scheme. Contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. SD3 paper: https://huggingface.co/papers/2403.03206v1.

Signature: (weighting_scheme: str, sigmas=None)
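Read directly off the source, the weight applied to each element σ of sigmas (computed elementwise) is:

```latex
w(\sigma) =
\begin{cases}
\sigma^{-2} & \text{if weighting\_scheme} = \texttt{"sigma\_sqrt"} \\[4pt]
\dfrac{2}{\pi\,(1 - 2\sigma + 2\sigma^{2})} & \text{if weighting\_scheme} = \texttt{"cosmap"} \\[4pt]
1 & \text{otherwise (e.g. } \texttt{"none"}\text{)}
\end{cases}
\qquad
1 - 2\sigma + 2\sigma^{2} = \sigma^{2} + (1 - \sigma)^{2} \ge \tfrac{1}{2}
```

The identity on the right shows the cosmap denominator never vanishes: the weight peaks at 4/π when σ = 1/2 and drops to 2/π at σ = 0 and σ = 1. Note that "sigma_sqrt" evaluates σ^-2, not a square root.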


```python
import math

import torch


def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """
    Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    # Every branch reads `sigmas`, so a tensor must be passed despite the None default.
    if weighting_scheme == "sigma_sqrt":
        # Despite the name, this is sigma ** -2, cast to float.
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        # cosmap weighting: 2 / (pi * (1 - 2*sigma + 2*sigma**2)), elementwise.
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        # Any other scheme name (e.g. "none") yields uniform weights of 1.
        weighting = torch.ones_like(sigmas)
    return weighting
```
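To see how the schemes differ numerically, here is a minimal stdlib-only sketch that mirrors the function's arithmetic on plain floats, assuming each σ lies in (0, 1] as in SD3's rectified-flow setup. The helpers sd3_weight and weighted_mse are hypothetical illustrations, not diffusers APIs; in particular, weighted_mse is only one plausible way a weight could scale a per-sample squared error, since the reduction used by the calling training scripts is not shown on this page.

```python
import math
from typing import Sequence


def sd3_weight(weighting_scheme: str, sigma: float) -> float:
    """Hypothetical scalar mirror of compute_loss_weighting_for_sd3 (assumes sigma > 0)."""
    if weighting_scheme == "sigma_sqrt":
        # Same arithmetic as the source: sigma ** -2, not a square root.
        return sigma**-2.0
    if weighting_scheme == "cosmap":
        # Denominator equals sigma**2 + (1 - sigma)**2, so it is always >= 0.5.
        bot = 1 - 2 * sigma + 2 * sigma**2
        return 2 / (math.pi * bot)
    # Any other scheme name falls through to a uniform weight of 1.
    return 1.0


def weighted_mse(pred: Sequence[float], target: Sequence[float], weight: float) -> float:
    """Hypothetical per-sample loss: mean of weight * squared error."""
    return sum(weight * (p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


if __name__ == "__main__":
    for sigma in (0.25, 0.5, 1.0):
        print(
            f"sigma={sigma}: "
            f"sigma_sqrt={sd3_weight('sigma_sqrt', sigma):.4f} "
            f"cosmap={sd3_weight('cosmap', sigma):.4f} "
            f"none={sd3_weight('none', sigma):.4f}"
        )
    # A sample at sigma = 0.5 under cosmap gets weight 4/pi on its squared error.
    print(weighted_mse([1.0, 2.0], [0.0, 2.0], sd3_weight("cosmap", 0.5)))
```

The printout makes the contrast visible: sigma_sqrt grows without bound as σ shrinks toward 0, while cosmap stays bounded between 2/π and 4/π and favors mid-range noise levels.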

Callers: 15 (the 12 listed are all functions named main, each with match score 0.90)

Calls: 1 (the float method, match score 0.80)

Tested by: no test coverage detected
