(y, timesteps)
| 328 | |
| 329 | |
def timed_adm(y, timesteps):
    """Pick one half of a doubled ADM conditioning tensor per sample, by timestep.

    When ``y`` is a 2-D ``torch.Tensor`` whose second dimension is 5632, it is
    treated as two concatenated 2816-wide halves:

    * the first half is used where ``timesteps > cutoff``;
    * the second half is used everywhere else.

    The cutoff is ``999.0 * (1.0 - adm_scaler_end)``. ``adm_scaler_end`` is read
    from this process's entry in ``patch_settings`` (keyed by ``os.getpid()``).
    Any other input is returned unchanged (the same object).

    Args:
        y: conditioning input. Only a ``torch.Tensor`` of shape ``(N, 5632)`` is
            transformed.
        timesteps: tensor compared element-wise against the cutoff.
            NOTE(review): assumed to hold one entry per row of ``y``; confirm
            against callers.

    Returns:
        A new ``(N, 2816)`` tensor when ``y`` was split, otherwise ``y`` itself.
    """
    half = 2816
    if not (isinstance(y, torch.Tensor) and y.dim() == 2 and y.shape[1] == 2 * half):
        return y

    cutoff = 999.0 * (1.0 - float(patch_settings[os.getpid()].adm_scaler_end))
    # Boolean mask with a trailing singleton dim so it broadcasts over the
    # feature axis; moved to y's device so torch.where can combine them.
    use_first_half = (timesteps > cutoff).to(y.device)[..., None]
    # torch.where selects values exactly. A multiply-and-add blend would turn
    # inf/NaN in the unused half into NaN, and cloning the halves first is
    # unnecessary because the result is a freshly allocated tensor.
    return torch.where(use_first_half, y[..., :half], y[..., half:])
| 337 | |
| 338 | |
| 339 | def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs): |
no test coverage detected