(
self,
video_or_images: Tensor,
cond: Optional[Tensor] = None,
return_loss=False,
return_codes=False,
return_recon=False,
return_discr_loss=False,
return_recon_loss_only=False,
apply_gradient_penalty=True,
video_contains_first_frame=True,
adversarial_loss_weight=None,
multiscale_adversarial_loss_weight=None,
)
| 1507 | |
| 1508 | @beartype |
| 1509 | def forward( |
| 1510 | self, |
| 1511 | video_or_images: Tensor, |
| 1512 | cond: Optional[Tensor] = None, |
| 1513 | return_loss=False, |
| 1514 | return_codes=False, |
| 1515 | return_recon=False, |
| 1516 | return_discr_loss=False, |
| 1517 | return_recon_loss_only=False, |
| 1518 | apply_gradient_penalty=True, |
| 1519 | video_contains_first_frame=True, |
| 1520 | adversarial_loss_weight=None, |
| 1521 | multiscale_adversarial_loss_weight=None, |
| 1522 | ): |
| 1523 | adversarial_loss_weight = default(adversarial_loss_weight, self.adversarial_loss_weight) |
| 1524 | multiscale_adversarial_loss_weight = default( |
| 1525 | multiscale_adversarial_loss_weight, self.multiscale_adversarial_loss_weight |
| 1526 | ) |
| 1527 | |
| 1528 | assert (return_loss + return_codes + return_discr_loss) <= 1 |
| 1529 | assert video_or_images.ndim in {4, 5} |
| 1530 | |
| 1531 | assert video_or_images.shape[-2:] == (self.image_size, self.image_size) |
| 1532 | |
| 1533 | # accept images for image pretraining (curriculum learning from images to video) |
| 1534 | |
| 1535 | is_image = video_or_images.ndim == 4 |
| 1536 | |
| 1537 | if is_image: |
| 1538 | video = rearrange(video_or_images, "b c ... -> b c 1 ...") |
| 1539 | video_contains_first_frame = True |
| 1540 | else: |
| 1541 | video = video_or_images |
| 1542 | |
| 1543 | batch, channels, frames = video.shape[:3] |
| 1544 | |
| 1545 | assert divisible_by( |
| 1546 | frames - int(video_contains_first_frame), self.time_downsample_factor |
| 1547 | ), f"number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}" |
| 1548 | |
| 1549 | # encoder |
| 1550 | |
| 1551 | x = self.encode(video, cond=cond, video_contains_first_frame=video_contains_first_frame) |
| 1552 | |
| 1553 | # lookup free quantization |
| 1554 | |
| 1555 | if self.use_fsq: |
| 1556 | quantized, codes = self.quantizers(x) |
| 1557 | |
| 1558 | aux_losses = self.zero |
| 1559 | quantizer_loss_breakdown = None |
| 1560 | else: |
| 1561 | (quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True) |
| 1562 | |
| 1563 | if return_codes and not return_recon: |
| 1564 | return codes |
| 1565 | |
| 1566 | # decoder |
no test coverage detected