Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from https://huggingface.co/papers/2312.00210. DREAM aligns training with sampling, making training more efficient and accurate at the cost of one extra forward pass without gradients. Args:
(
unet: UNet2DConditionModel,
noise_scheduler: SchedulerMixin,
timesteps: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor,
target: torch.Tensor,
encoder_hidden_states: torch.Tensor,
dream_detail_preservation: float = 1.0,
)
| 241 | |
| 242 | |
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
    https://huggingface.co/papers/2312.00210. DREAM aligns training with sampling, making training more
    efficient and accurate at the cost of one extra forward pass without gradients.

    Args:
        `unet`: The unet being trained. It is called once under `torch.no_grad()` to obtain its current
            noise prediction for `noisy_latents`.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep. Its
            `alphas_cumprod` and `config.prediction_type` are read.
        `timesteps`: The timesteps for the noise_scheduler to use (used to index `alphas_cumprod`).
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noised latents from the training loop. The per-sample scheduler
            coefficients are broadcast over every dimension after the first, so any rank is accepted.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level (the exponent
            `p` below). See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.

    Raises:
        NotImplementedError: If the scheduler's prediction type is "v_prediction".
        ValueError: If the scheduler's prediction type is anything other than "epsilon" or "v_prediction".
    """
    # Validate the configuration before the extra forward pass, so an unsupported setup fails
    # immediately instead of after spending a full unet evaluation.
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    if prediction_type != "epsilon":
        raise ValueError(f"Unknown prediction type {prediction_type}")

    # Gather one alpha_bar per sample and shape it (B, 1, ..., 1) to match the rank of the latents,
    # rather than hard-coding three trailing axes (which only fits 4D latents).
    broadcast_shape = (-1,) + (1,) * (noisy_latents.ndim - 1)
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps].reshape(broadcast_shape)
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    with torch.no_grad():
        predicted_noise = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Gap between the true noise and the model's current estimate; detached so no gradient flows
    # through the rectification terms. Scaled in place so the tensor keeps its own dtype.
    delta_noise = (noise - predicted_noise).detach()
    delta_noise.mul_(dream_lambda)

    # Shift the noisy input by the scaled gap and move the epsilon target by the same amount.
    adjusted_noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
    adjusted_target = target.add(delta_noise)

    return adjusted_noisy_latents, adjusted_target
| 295 | |
| 296 | |
| 297 | def unet_lora_state_dict(unet: UNet2DConditionModel) -> dict[str, torch.Tensor]: |