MCPcopy
hub / github.com/Wan-Video/Wan2.1 / set_timesteps

Method set_timesteps

wan/utils/fm_solvers.py:228–291  ·  view source on GitHub ↗

Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): Number of inference (sampling) steps. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved. If `None`, the timesteps are not moved.

(
        self,
        num_inference_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    )

Source from the content-addressed store, hash-verified

226
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
def set_timesteps(
    self,
    num_inference_steps: Union[int, None] = None,
    device: Union[str, torch.device] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[Union[float, None]] = None,
    shift: Optional[Union[float, None]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`, *optional*):
            Number of sampling steps N. Required when `sigmas` is not given;
            the sigmas are then spaced linearly from `self.sigma_max` towards
            `self.sigma_min` (N + 1 points with the last one dropped).
            Ignored when `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the
            timesteps are not moved. `self.sigmas` is not moved by this method.
        sigmas (`List[float]` or array-like, *optional*):
            Custom, not-yet-shifted sigma schedule. Converted with
            `np.asarray`, so a plain Python list works as well as an ndarray.
        mu (`float`, *optional*):
            Forwarded to `self.time_shift`; required when
            `config.use_dynamic_shifting` is True.
        shift (`float`, *optional*):
            Static shift factor used when dynamic shifting is disabled.
            Defaults to `config.shift`.

    Raises:
        ValueError: if `mu` is missing while dynamic shifting is enabled, if
            neither `num_inference_steps` nor `sigmas` is given, or if
            `config.final_sigmas_type` is not 'zero' or 'sigma_min'.

    Side effects:
        Sets `self.sigmas` (float32 tensor of N shifted sigmas plus the terminal
        sigma), `self.timesteps` (int64 tensor of length N),
        `self.num_inference_steps` (N), and resets `self.model_outputs` to
        `config.solver_order` `None` entries.
    """
    if self.config.use_dynamic_shifting and mu is None:
        raise ValueError(
            " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
        )

    if sigmas is None:
        if num_inference_steps is None:
            raise ValueError(
                "`num_inference_steps` must be provided when `sigmas` is None")
        # N + 1 evenly spaced points with the last (sigma_min) dropped, so
        # exactly N sigmas remain.
        sigmas = np.linspace(self.sigma_max, self.sigma_min,
                             num_inference_steps + 1)[:-1]  # pyright: ignore
    else:
        # Element-wise arithmetic below needs an array: with a Python list,
        # `shift * sigmas` would repeat the list (int shift) or raise (float).
        # asarray is a no-op for ndarrays and later ops never write in place,
        # so the caller's array is not mutated.
        sigmas = np.asarray(sigmas)

    if self.config.use_dynamic_shifting:
        sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore
    else:
        if shift is None:
            shift = self.config.shift
        # Static shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma).
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

    if self.config.final_sigmas_type == "sigma_min":
        sigma_last = ((1 - self.alphas_cumprod[0]) /
                      self.alphas_cumprod[0])**0.5
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    # Timesteps are derived before the terminal sigma is appended, so there
    # are N timesteps but N + 1 sigmas.
    timesteps = sigmas * self.config.num_train_timesteps
    sigmas = np.concatenate([sigmas, [sigma_last]]).astype(
        np.float32)  # pyright: ignore

    self.sigmas = torch.from_numpy(sigmas)
    # The int64 cast truncates fractional timesteps; it does not round them.
    self.timesteps = torch.from_numpy(timesteps).to(
        device=device, dtype=torch.int64)

    self.num_inference_steps = len(timesteps)

    self.model_outputs = [None] * self.config.solver_order
    # NOTE(review): the page header labels this method as lines 228–291, but
    # the listing stops at 285; any later statements are not shown here.

Callers (6)

generate · Method · 0.95
generate · Method · 0.95
generate · Method · 0.95
generate · Method · 0.95
mp_worker · Method · 0.95
retrieve_timesteps · Function · 0.45

Calls (1)

time_shift · Method · 0.95

Tested by

no test coverage detected