Function `get_schedule(num_steps: int, image_seq_len: int, base_shift: float = 0.5, max_shift: float = 1.15, shift: bool = True)`
Source from the content-addressed store, hash-verified
| 287 | |
| 288 | |
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return a descending timestep schedule running from 1 down to 0.

    The schedule has ``num_steps + 1`` evenly spaced points. The extra point
    is the terminal 0. When ``shift`` is true, the points are warped by
    ``time_shift`` with ``sigma=1.0``. The ``mu`` it uses comes from
    ``get_lin_function(y1=base_shift, y2=max_shift)`` evaluated at
    ``image_seq_len``.

    Args:
        num_steps: Number of sampling steps; the result has one more entry.
        image_seq_len: Sequence length fed to the linear mu estimator
            (only used when ``shift`` is true).
        base_shift: ``y1`` passed to ``get_lin_function``.
        max_shift: ``y2`` passed to ``get_lin_function``.
        shift: Whether to apply ``time_shift`` to the linear grid.

    Returns:
        The schedule as a plain Python list of floats.
    """
    # One more point than steps so the schedule terminates exactly at zero.
    grid = torch.linspace(1, 0, num_steps + 1)

    if not shift:
        return grid.tolist()

    # Original intent: shift the schedule to favor high timesteps for
    # higher-signal images. mu is estimated linearly between two points
    # whose y-values are base_shift and max_shift.
    estimate_mu = get_lin_function(y1=base_shift, y2=max_shift)
    mu = estimate_mu(image_seq_len)
    warped = time_shift(mu, 1.0, grid)
    return warped.tolist()
| 306 | |
| 307 | |
| 308 | def denoise( |
Tested by
no test coverage detected
Used in the wild: real call sites across dependent graphs
searching dependent graphs…