(pipe, step: int = 0, timestep: int = 0, kwargs: dict | None = None)
| 52 | |
| 53 | |
| 54 | def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict | None = None): |
| 55 | if kwargs is None: |
| 56 | kwargs = {} |
| 57 | t0 = time.time() |
| 58 | if devices.backend == "ipex": |
| 59 | torch.xpu.synchronize(devices.device) |
| 60 | elif devices.backend in {"cuda", "zluda", "rocm"}: |
| 61 | torch.cuda.synchronize(devices.device) |
| 62 | |
| 63 | if shared.state.paused: |
| 64 | log.debug('Sampling paused') |
| 65 | while shared.state.paused: |
| 66 | if shared.state.interrupted or shared.state.skipped: |
| 67 | raise AssertionError('Interrupted...') |
| 68 | time.sleep(0.1) |
| 69 | |
| 70 | image = kwargs.get('image', None) |
| 71 | if image is not None: |
| 72 | shared.state.current_image = image |
| 73 | shared.state.current_latent = None |
| 74 | shared.state.step() # increase step |
| 75 | shared.state.preview_job = -1 # indicate that preview image has changed |
| 76 | debug_callback(f'Callback: step={step} timestep={timestep} image={image if image is not None else None} kwargs={list(kwargs)}') |
| 77 | return kwargs |
| 78 | |
| 79 | latents = kwargs.get('latents', None) |
| 80 | if debug: |
| 81 | debug_callback(f'Callback: step={step} timestep={timestep} latents={latents.shape if latents is not None else None} kwargs={list(kwargs)}') |
| 82 | if shared.state.sampling_steps == 0 and getattr(pipe, 'num_timesteps', 0) > 0: |
| 83 | shared.state.sampling_steps = pipe.num_timesteps |
| 84 | shared.state.step() |
| 85 | if shared.state.interrupted or shared.state.skipped: |
| 86 | raise AssertionError('Interrupted...') |
| 87 | if latents is None: |
| 88 | return kwargs |
| 89 | elif shared.opts.nan_skip: |
| 90 | assert not torch.isnan(latents[..., 0, 0]).all(), f'NaN detected at step {step}: Skipping...' |
| 91 | if p is None: |
| 92 | return kwargs |
| 93 | if len(getattr(p, 'ip_adapter_names', [])) > 0 and p.ip_adapter_names[0] != 'None': |
| 94 | ip_adapter_scales = list(p.ip_adapter_scales) |
| 95 | ip_adapter_starts = list(p.ip_adapter_starts) |
| 96 | ip_adapter_ends = list(p.ip_adapter_ends) |
| 97 | if any(end != 1 for end in ip_adapter_ends) or any(start != 0 for start in ip_adapter_starts): |
| 98 | if 'Flux' in pipe.__class__.__name__: |
| 99 | ip_adapter_scales = [(ip_adapter_starts[0] + (ip_adapter_ends[0] - ip_adapter_starts[0]) * (i / (19 - 1))) for i in range(19)] |
| 100 | else: |
| 101 | for i in range(len(ip_adapter_scales)): |
| 102 | ip_adapter_scales[i] *= float(step >= pipe.num_timesteps * ip_adapter_starts[i]) |
| 103 | ip_adapter_scales[i] *= float(step <= pipe.num_timesteps * ip_adapter_ends[i]) |
| 104 | debug_callback(f"Callback: IP Adapter scales={ip_adapter_scales}") |
| 105 | pipe.set_ip_adapter_scale(ip_adapter_scales) |
| 106 | if step != getattr(pipe, 'num_timesteps', 0): |
| 107 | kwargs = processing_correction.correction_callback(p, timestep, kwargs, pipe=pipe, initial=step == 0, step=step) |
| 108 | kwargs = prompt_callback(step, kwargs) # monkey patch for diffusers callback issues |
| 109 | |
| 110 | if step == 0: |
| 111 | pipe._cfg_end_applied = False # pylint: disable=protected-access |
nothing calls this directly
no test coverage detected