| 1556 | CATEGORY = "WanVideoWrapper" |
| 1557 | |
def process(self, num_frames, phantom_cfg_scale, phantom_start_percent, phantom_end_percent, phantom_latent_1, phantom_latent_2=None, phantom_latent_3=None, phantom_latent_4=None, vace_embeds=None):
    """Assemble the embeds dict from up to four phantom latents.

    Each provided latent's ``"samples"`` tensor has its leading dim
    squeezed (``squeeze(0)``) and the results are concatenated, in
    argument order, along dim 1. The combined tensor is unpacked as a
    4-D ``(C, T, H, W)`` shape, so anything that is not 4-D after the
    squeeze raises on the unpack.

    Args:
        num_frames: Frame count; stored as-is and used to derive the
            temporal entry of ``target_shape``.
        phantom_cfg_scale: Stored unchanged under ``"phantom_cfg_scale"``.
        phantom_start_percent: Stored unchanged under
            ``"phantom_start_percent"``.
        phantom_end_percent: Stored unchanged under
            ``"phantom_end_percent"``.
        phantom_latent_1: Mapping with a ``"samples"`` tensor (required).
        phantom_latent_2: Optional mapping like ``phantom_latent_1``.
        phantom_latent_3: Optional mapping like ``phantom_latent_1``.
        phantom_latent_4: Optional mapping like ``phantom_latent_1``.
        vace_embeds: Optional mapping; when given, a fixed set of VACE
            keys is copied from it into the result.

    Returns:
        A 1-tuple holding the embeds dict.

    Raises:
        KeyError: If ``vace_embeds`` is given but lacks one of the
            copied keys.
    """
    merged = phantom_latent_1["samples"].squeeze(0)
    # Append the optional latents one at a time, preserving argument order.
    for extra_latent in (phantom_latent_2, phantom_latent_3, phantom_latent_4):
        if extra_latent is None:
            continue
        merged = torch.cat([merged, extra_latent["samples"].squeeze(0)], dim=1)

    # Only the two trailing sizes are used below; the 4-way unpack is kept
    # so the rank requirement on the combined tensor is unchanged.
    _channels, _frames, height, width = merged.shape

    log.info(f"Phantom latents shape: {merged.shape}")

    # NOTE(review): the leading 16 and the "* 8" factors are hard-coded
    # here - presumably the latent channel count and the latent-to-pixel
    # scale; confirm against the VAE in use.
    latent_frames = (num_frames - 1) // VAE_STRIDE[0] + 1
    latent_height = height * 8 // VAE_STRIDE[1]
    latent_width = width * 8 // VAE_STRIDE[2]
    target_shape = (16, latent_frames, latent_height, latent_width)

    embeds = {
        "target_shape": target_shape,
        "num_frames": num_frames,
        "phantom_latents": merged,
        "phantom_cfg_scale": phantom_cfg_scale,
        "phantom_start_percent": phantom_start_percent,
        "phantom_end_percent": phantom_end_percent,
    }

    if vace_embeds is not None:
        # Forward exactly these VACE entries, in this order.
        vace_keys = (
            "vace_context",
            "vace_scale",
            "has_ref",
            "vace_start_percent",
            "vace_end_percent",
            "vace_seq_len",
            "additional_vace_inputs",
        )
        embeds.update({key: vace_embeds[key] for key in vace_keys})

    return (embeds,)
| 1595 | |
| 1596 | class WanVideoControlEmbeds: |
| 1597 | @classmethod |