(p, timestep, kwargs, pipe=None, initial: bool = False, step: int = 0)
| 212 | |
| 213 | |
| 214 | def correction_callback(p, timestep, kwargs, pipe=None, initial: bool = False, step: int = 0): |
| 215 | if pipe and pipe.__class__.__name__ in ['HiDreamO1Pipeline', 'HiDreamO1ImagePipeline']: |
| 216 | return kwargs |
| 217 | if initial: |
| 218 | if not any([p.hdr_clamp, p.hdr_mode, p.hdr_maximize, p.hdr_sharpen, p.hdr_color, p.hdr_brightness, p.hdr_tint_ratio]): |
| 219 | p.correction_skip = True |
| 220 | return kwargs |
| 221 | # always skip for detailer passes (already-corrected image, different resolution) |
| 222 | if getattr(p, 'recursion', False): |
| 223 | p.correction_skip = True |
| 224 | return kwargs |
| 225 | # optionally skip for hires pass |
| 226 | if getattr(p, 'is_hr_pass', False) and not getattr(p, 'hdr_apply_hires', True): |
| 227 | p.correction_skip = True |
| 228 | return kwargs |
| 229 | p.correction_skip = False |
| 230 | p.correction_warned = False |
| 231 | total = getattr(pipe, 'num_timesteps', 0) if pipe is not None else 0 |
| 232 | if total > 0: |
| 233 | p.correction_total_steps = total |
| 234 | p.correction_steps_mid = max(int(total * 0.5), 1) # 20%-70% range |
| 235 | p.correction_steps_late = max(int(total * 0.2), 1) # last 20% |
| 236 | elif pipe is not None: |
| 237 | p.correction_total_steps = 0 |
| 238 | p.correction_steps_mid = _count_steps_in_range(pipe, 600, 900) |
| 239 | p.correction_steps_late = _count_steps_below(pipe, 200) |
| 240 | elif getattr(p, 'correction_skip', False): |
| 241 | return kwargs |
| 242 | latents = kwargs["latents"] |
| 243 | if debug_enabled: |
| 244 | debug(f'Correction callback: step={step} timestep={timestep} latents_shape={latents.shape} total={getattr(p, "correction_total_steps", "unset")} skip={getattr(p, "correction_skip", "unset")}') |
| 245 | if len(latents.shape) <= 3: # packed latent |
| 246 | if pipe is None: |
| 247 | if not getattr(p, 'correction_warned', False): |
| 248 | log.warning(f'Latent correction: shape={latents.shape} packed latent but no pipe reference') |
| 249 | p.correction_warned = True |
| 250 | return kwargs |
| 251 | unpacked, pack_type = _unpack_latents(latents, pipe, p) |
| 252 | if pack_type == 'unknown': |
| 253 | if not getattr(p, 'correction_warned', False): |
| 254 | log.warning(f'Latent correction: shape={latents.shape} unknown packed format') |
| 255 | p.correction_warned = True |
| 256 | return kwargs |
| 257 | for i in range(unpacked.shape[0]): |
| 258 | unpacked[i] = correction(p, timestep, unpacked[i], step=step) |
| 259 | kwargs["latents"] = _repack_latents(unpacked, pack_type, pipe, p) |
| 260 | elif len(latents.shape) == 4: # standard batched latent |
| 261 | for i in range(latents.shape[0]): |
| 262 | latents[i] = correction(p, timestep, latents[i], step=step) |
| 263 | if debug_enabled: |
| 264 | debug(f"Full Mean: {latents[i].mean().item()}") |
| 265 | debug(f"Channel Means: {latents[i].mean(dim=(-1, -2), keepdim=True).flatten().float().cpu().numpy()}") |
| 266 | debug(f"Channel Mins: {latents[i].min(-1, keepdim=True)[0].min(-2, keepdim=True)[0].flatten().float().cpu().numpy()}") |
| 267 | debug(f"Channel Maxes: {latents[i].max(-1, keepdim=True)[0].min(-2, keepdim=True)[0].flatten().float().cpu().numpy()}") |
| 268 | kwargs["latents"] = latents |
| 269 | elif len(latents.shape) == 5 and latents.shape[0] == 1: # probably animatediff |
| 270 | latents = latents.squeeze(0).permute(1, 0, 2, 3) |
| 271 | for i in range(latents.shape[0]): |
nothing calls this directly
no test coverage detected