MCPcopy
hub / github.com/vladmandic/sdnext / diffusers_callback

Function diffusers_callback

modules/processing_callbacks.py:54–234  ·  view source on GitHub ↗
(pipe, step: int = 0, timestep: int = 0, kwargs: dict | None = None)

Source from the content-addressed store, hash-verified

52
53
54def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict | None = None):
55 if kwargs is None:
56 kwargs = {}
57 t0 = time.time()
58 if devices.backend == "ipex":
59 torch.xpu.synchronize(devices.device)
60 elif devices.backend in {"cuda", "zluda", "rocm"}:
61 torch.cuda.synchronize(devices.device)
62
63 if shared.state.paused:
64 log.debug('Sampling paused')
65 while shared.state.paused:
66 if shared.state.interrupted or shared.state.skipped:
67 raise AssertionError('Interrupted...')
68 time.sleep(0.1)
69
70 image = kwargs.get('image', None)
71 if image is not None:
72 shared.state.current_image = image
73 shared.state.current_latent = None
74 shared.state.step() # increase step
75 shared.state.preview_job = -1 # indicate that preview image has changed
76 debug_callback(f'Callback: step={step} timestep={timestep} image={image if image is not None else None} kwargs={list(kwargs)}')
77 return kwargs
78
79 latents = kwargs.get('latents', None)
80 if debug:
81 debug_callback(f'Callback: step={step} timestep={timestep} latents={latents.shape if latents is not None else None} kwargs={list(kwargs)}')
82 if shared.state.sampling_steps == 0 and getattr(pipe, 'num_timesteps', 0) > 0:
83 shared.state.sampling_steps = pipe.num_timesteps
84 shared.state.step()
85 if shared.state.interrupted or shared.state.skipped:
86 raise AssertionError('Interrupted...')
87 if latents is None:
88 return kwargs
89 elif shared.opts.nan_skip:
90 assert not torch.isnan(latents[..., 0, 0]).all(), f'NaN detected at step {step}: Skipping...'
91 if p is None:
92 return kwargs
93 if len(getattr(p, 'ip_adapter_names', [])) > 0 and p.ip_adapter_names[0] != 'None':
94 ip_adapter_scales = list(p.ip_adapter_scales)
95 ip_adapter_starts = list(p.ip_adapter_starts)
96 ip_adapter_ends = list(p.ip_adapter_ends)
97 if any(end != 1 for end in ip_adapter_ends) or any(start != 0 for start in ip_adapter_starts):
98 if 'Flux' in pipe.__class__.__name__:
99 ip_adapter_scales = [(ip_adapter_starts[0] + (ip_adapter_ends[0] - ip_adapter_starts[0]) * (i / (19 - 1))) for i in range(19)]
100 else:
101 for i in range(len(ip_adapter_scales)):
102 ip_adapter_scales[i] *= float(step >= pipe.num_timesteps * ip_adapter_starts[i])
103 ip_adapter_scales[i] *= float(step <= pipe.num_timesteps * ip_adapter_ends[i])
104 debug_callback(f"Callback: IP Adapter scales={ip_adapter_scales}")
105 pipe.set_ip_adapter_scale(ip_adapter_scales)
106 if step != getattr(pipe, 'num_timesteps', 0):
107 kwargs = processing_correction.correction_callback(p, timestep, kwargs, pipe=pipe, initial=step == 0, step=step)
108 kwargs = prompt_callback(step, kwargs) # monkey patch for diffusers callback issues
109
110 if step == 0:
111 pipe._cfg_end_applied = False # pylint: disable=protected-access

Callers

nothing calls this directly

Calls 9

prompt_callback — Function · 0.85
set_ip_adapter_scale — Method · 0.80
_unpatchify_latents — Method · 0.80
view — Method · 0.80
get — Method · 0.45
step — Method · 0.45
_unpack_latents — Method · 0.45
to — Method · 0.45
add — Method · 0.45

Tested by

no test coverage detected