DESCRIPTION = "Dilates a latent by a grid size."

def dilate_latent(
    self, latent: dict, horizontal_scale: int, vertical_scale: int
) -> tuple:
    """Spread latent values onto a sparser spatial grid, leaving gaps between them.

    The value at spatial position ``(y, x)`` is copied to
    ``(y * vertical_scale, x * horizontal_scale)`` of a zero-filled tensor whose
    last two dimensions are enlarged by the scale factors. A ``noise_mask`` of
    matching size is produced alongside. Filled positions receive the incoming
    mask value, or ``1.0`` when the latent has no mask. The inserted gap
    positions receive ``-1.0``.

    Args:
        latent: Latent dict. ``latent["samples"]`` is indexed as a 5-D tensor
            whose dims 3 and 4 are dilated. It is presumably
            (batch, channels, frames, height, width); confirm against callers.
            An optional ``latent["noise_mask"]`` must be assignable into a
            ``(batch, 1, frames, height, width)`` slice. All other keys are
            carried through to the output unchanged.
        horizontal_scale: Dilation factor for the last dimension (must be >= 1).
        vertical_scale: Dilation factor for the second-to-last dimension
            (must be >= 1).

    Returns:
        A one-element tuple holding the dilated latent dict. When both scales
        are 1, the input dict itself is returned unchanged.

    Raises:
        ValueError: If either scale is less than 1.
    """
    if horizontal_scale < 1 or vertical_scale < 1:
        # Without this guard, a 0 step fails inside slicing and a negative size
        # fails inside torch.zeros, both with messages unrelated to the inputs.
        raise ValueError(
            "horizontal_scale and vertical_scale must be >= 1, got "
            f"{horizontal_scale} and {vertical_scale}"
        )
    if horizontal_scale == 1 and vertical_scale == 1:
        return (latent,)

    samples = latent["samples"]
    mask = latent.get("noise_mask")
    dilated_shape = samples.shape[:3] + (
        samples.shape[3] * vertical_scale,
        samples.shape[4] * horizontal_scale,
    )

    # Original values land on every scale-th row/column; the gaps stay zero.
    dilated_samples = torch.zeros(
        dilated_shape, device=samples.device, dtype=samples.dtype
    )
    dilated_samples[..., ::vertical_scale, ::horizontal_scale] = samples

    # Single-channel mask over the dilated grid: -1.0 marks the inserted gaps,
    # the original positions carry the incoming mask (or 1.0 if none was given).
    dilated_mask_shape = (
        dilated_samples.shape[0],
        1,
        dilated_samples.shape[2],
        dilated_samples.shape[3],
        dilated_samples.shape[4],
    )
    dilated_mask = torch.full(
        dilated_mask_shape, -1.0, device=samples.device, dtype=samples.dtype
    )
    dilated_mask[..., ::vertical_scale, ::horizontal_scale] = (
        mask if mask is not None else 1.0
    )

    # Copy the incoming dict so auxiliary keys (e.g. "batch_index") survive;
    # only the tensors that changed shape are replaced. The input is not mutated.
    dilated_latent = {
        **latent,
        "samples": dilated_samples,
        "noise_mask": dilated_mask,
    }
    return (dilated_latent,)
| 398 | |
| 399 | |
| 400 | @comfy_node(name="LTXVAddLatentGuide") |