MCPcopy
hub / github.com/vladmandic/sdnext / correction_callback

Function correction_callback

modules/processing_correction.py:214–279  ·  view source on GitHub ↗
(p, timestep, kwargs, pipe=None, initial: bool = False, step: int = 0)

Source from the content-addressed store, hash-verified

212
213
214def correction_callback(p, timestep, kwargs, pipe=None, initial: bool = False, step: int = 0):
215 if pipe and pipe.__class__.__name__ in ['HiDreamO1Pipeline', 'HiDreamO1ImagePipeline']:
216 return kwargs
217 if initial:
218 if not any([p.hdr_clamp, p.hdr_mode, p.hdr_maximize, p.hdr_sharpen, p.hdr_color, p.hdr_brightness, p.hdr_tint_ratio]):
219 p.correction_skip = True
220 return kwargs
221 # always skip for detailer passes (already-corrected image, different resolution)
222 if getattr(p, 'recursion', False):
223 p.correction_skip = True
224 return kwargs
225 # optionally skip for hires pass
226 if getattr(p, 'is_hr_pass', False) and not getattr(p, 'hdr_apply_hires', True):
227 p.correction_skip = True
228 return kwargs
229 p.correction_skip = False
230 p.correction_warned = False
231 total = getattr(pipe, 'num_timesteps', 0) if pipe is not None else 0
232 if total > 0:
233 p.correction_total_steps = total
234 p.correction_steps_mid = max(int(total * 0.5), 1) # 20%-70% range
235 p.correction_steps_late = max(int(total * 0.2), 1) # last 20%
236 elif pipe is not None:
237 p.correction_total_steps = 0
238 p.correction_steps_mid = _count_steps_in_range(pipe, 600, 900)
239 p.correction_steps_late = _count_steps_below(pipe, 200)
240 elif getattr(p, 'correction_skip', False):
241 return kwargs
242 latents = kwargs["latents"]
243 if debug_enabled:
244 debug(f'Correction callback: step={step} timestep={timestep} latents_shape={latents.shape} total={getattr(p, "correction_total_steps", "unset")} skip={getattr(p, "correction_skip", "unset")}')
245 if len(latents.shape) <= 3: # packed latent
246 if pipe is None:
247 if not getattr(p, 'correction_warned', False):
248 log.warning(f'Latent correction: shape={latents.shape} packed latent but no pipe reference')
249 p.correction_warned = True
250 return kwargs
251 unpacked, pack_type = _unpack_latents(latents, pipe, p)
252 if pack_type == 'unknown':
253 if not getattr(p, 'correction_warned', False):
254 log.warning(f'Latent correction: shape={latents.shape} unknown packed format')
255 p.correction_warned = True
256 return kwargs
257 for i in range(unpacked.shape[0]):
258 unpacked[i] = correction(p, timestep, unpacked[i], step=step)
259 kwargs["latents"] = _repack_latents(unpacked, pack_type, pipe, p)
260 elif len(latents.shape) == 4: # standard batched latent
261 for i in range(latents.shape[0]):
262 latents[i] = correction(p, timestep, latents[i], step=step)
263 if debug_enabled:
264 debug(f"Full Mean: {latents[i].mean().item()}")
265 debug(f"Channel Means: {latents[i].mean(dim=(-1, -2), keepdim=True).flatten().float().cpu().numpy()}")
266 debug(f"Channel Mins: {latents[i].min(-1, keepdim=True)[0].min(-2, keepdim=True)[0].flatten().float().cpu().numpy()}")
267 debug(f"Channel Maxes: {latents[i].max(-1, keepdim=True)[0].min(-2, keepdim=True)[0].flatten().float().cpu().numpy()}")
268 kwargs["latents"] = latents
269 elif len(latents.shape) == 5 and latents.shape[0] == 1: # probably animatediff
270 latents = latents.squeeze(0).permute(1, 0, 2, 3)
271 for i in range(latents.shape[0]):

Callers

nothing calls this directly

Calls 7

_count_steps_in_range · Function · 0.85
_count_steps_below · Function · 0.85
_unpack_latents · Function · 0.85
correction · Function · 0.85
_repack_latents · Function · 0.85
debug · Function · 0.50
flatten · Method · 0.45

Tested by

no test coverage detected