
Function compute_dream_and_update_latents

src/diffusers/training_utils.py:243–294 (github.com/huggingface/diffusers)

Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from https://huggingface.co/papers/2312.00210. DREAM aligns training with sampling, making training more efficient and accurate at the cost of one extra forward pass through the UNet without gradients.
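For the epsilon-prediction branch (the only one the function implements), the update can be written out as below. The symbol mapping is an aid to reading the code, not notation taken from the source: ᾱ_t is `alphas_cumprod` gathered at `timesteps`, p is `dream_detail_preservation`, ε is `noise`, ε̂_θ is the gradient-free UNet prediction `pred`, c is `encoder_hidden_states`, x_t is `noisy_latents`, and y is `target`. The last line assumes the caller built x_t with the standard forward process and uses y = ε, as is usual for epsilon prediction.

```latex
\begin{aligned}
\lambda_t &= \bigl(\sqrt{1-\bar\alpha_t}\bigr)^{p},
  \qquad \Delta = \lambda_t\,\bigl(\epsilon - \hat\epsilon_\theta(x_t, t, c)\bigr),\\
\tilde x_t &= x_t + \sqrt{1-\bar\alpha_t}\,\Delta,
  \qquad \tilde y = y + \Delta,\\
\text{if } x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ y = \epsilon:\quad
\tilde x_t &= \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,(\epsilon + \Delta),
  \qquad \tilde y = \epsilon + \Delta .
\end{aligned}
```

In other words, under that assumption the adjusted pair is the ordinary forward process and target with ε replaced by ε + Δ.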

(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
)
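The one parameter with a default, `dream_detail_preservation`, is used by the function as the exponent p in λ = (sqrt(1 − ᾱ_t))^p. The pure-Python sketch below evaluates that weight for a few scalar values of ᾱ_t (`alphas_cumprod` at one timestep) and p. `dream_lambda` here is a hypothetical helper, not a diffusers function, and plain floats stand in for the per-sample tensors the real code broadcasts over.

```python
import math


def dream_lambda(alpha_bar: float, p: float = 1.0) -> float:
    """Hypothetical scalar helper: DREAM weight (sqrt(1 - alpha_bar)) ** p.

    Mirrors `sqrt_one_minus_alphas_cumprod**dream_detail_preservation`
    for a single timestep instead of a batch of them.
    """
    return math.sqrt(1.0 - alpha_bar) ** p


# alpha_bar near 1 means a low-noise timestep; near 0 means a high-noise one.
for alpha_bar in (0.99, 0.5, 0.01):
    weights = [round(dream_lambda(alpha_bar, p), 4) for p in (0.5, 1.0, 2.0)]
    print(f"alpha_bar={alpha_bar}: lambda for p=0.5, 1, 2 -> {weights}")
```

Because sqrt(1 − ᾱ_t) lies strictly between 0 and 1 for 0 < ᾱ_t < 1, raising p shrinks λ, and the shrinkage is strongest at low-noise timesteps where ᾱ_t is close to 1; p = 0 gives λ = 1 at every timestep.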


import torch
from diffusers import SchedulerMixin, UNet2DConditionModel  # needed when using this outside training_utils.py


def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
    https://huggingface.co/papers/2312.00210. DREAM aligns training with sampling, making training more
    efficient and accurate at the cost of one extra forward pass without gradients.

    Args:
        `unet`: The UNet being trained, used to make a prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timesteps.
        `timesteps`: The timesteps for the noise_scheduler to use.
        `noise`: A tensor of noise with the same shape as noisy_latents.
        `noisy_latents`: Previously noised latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: The exponent p in lambda = sqrt(1 - alphas_cumprod) ** p, which sets
            the detail preservation level. See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
    """
    # Gather alpha_bar_t for each sample and reshape to (B, 1, 1, 1) so it broadcasts over the latents.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    # Extra forward pass without gradients: the model's own estimate at this training step.
    pred = None
    with torch.no_grad():
        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    _noisy_latents, _target = (None, None)
    if noise_scheduler.config.prediction_type == "epsilon":
        predicted_noise = pred
        # Prediction error scaled by lambda, detached so no gradient flows through the correction.
        delta_noise = (noise - predicted_noise).detach()
        delta_noise.mul_(dream_lambda)
        # Shift the input latents and the target by the same correction.
        _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
        _target = target.add(delta_noise)
    elif noise_scheduler.config.prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    return _noisy_latents, _target
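To follow the epsilon branch end to end without loading a model, here is a scalar sketch in plain Python. `dream_update_scalar` is a hypothetical helper, not part of diffusers: it takes the UNet's no-grad prediction as a number instead of calling a model, builds `noisy_latents` and `target` itself under the assumption of the standard forward process with an epsilon target, and checks the prediction type before doing any work (the original runs the UNet first).

```python
import math


def dream_update_scalar(x0, noise, pred_noise, alpha_bar, p=1.0, prediction_type="epsilon"):
    """Hypothetical scalar sketch of the epsilon branch of compute_dream_and_update_latents."""
    if prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    if prediction_type != "epsilon":
        raise ValueError(f"Unknown prediction type {prediction_type}")

    sqrt_one_minus = math.sqrt(1.0 - alpha_bar)
    # Assumed caller behaviour: standard forward process and epsilon target.
    noisy_latent = math.sqrt(alpha_bar) * x0 + sqrt_one_minus * noise
    target = noise

    # Same arithmetic as the library code, on floats instead of tensors.
    dream_lambda = sqrt_one_minus**p
    delta_noise = dream_lambda * (noise - pred_noise)
    new_noisy = noisy_latent + sqrt_one_minus * delta_noise
    new_target = target + delta_noise
    return new_noisy, new_target


new_x, new_y = dream_update_scalar(x0=0.8, noise=0.5, pred_noise=0.3, alpha_bar=0.64)
# The adjusted latent should equal the forward process applied with the adjusted noise.
print(new_x, math.sqrt(0.64) * 0.8 + math.sqrt(1 - 0.64) * new_y)
```

With a perfect prediction (`pred_noise == noise`) the correction is zero and both values come back unchanged, so the adjustment only takes effect where the model's own estimate is off.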

Callers (1)
  main (function, 0.90)

Calls (3)
  unet (function, 0.85)
  add (method, 0.80)
  to (method, 0.45)

Tested by
  no test coverage detected
