github.com/zai-org/CogVideo

Method forward

sat/sgm/modules/autoencoding/magvit2_pytorch.py:1509–1751


```python
@beartype
def forward(
    self,
    video_or_images: Tensor,
    cond: Optional[Tensor] = None,
    return_loss=False,
    return_codes=False,
    return_recon=False,
    return_discr_loss=False,
    return_recon_loss_only=False,
    apply_gradient_penalty=True,
    video_contains_first_frame=True,
    adversarial_loss_weight=None,
    multiscale_adversarial_loss_weight=None,
):
    adversarial_loss_weight = default(adversarial_loss_weight, self.adversarial_loss_weight)
    multiscale_adversarial_loss_weight = default(
        multiscale_adversarial_loss_weight, self.multiscale_adversarial_loss_weight
    )

    assert (return_loss + return_codes + return_discr_loss) <= 1
    assert video_or_images.ndim in {4, 5}

    assert video_or_images.shape[-2:] == (self.image_size, self.image_size)

    # accept images for image pretraining (curriculum learning from images to video)

    is_image = video_or_images.ndim == 4

    if is_image:
        video = rearrange(video_or_images, "b c ... -> b c 1 ...")
        video_contains_first_frame = True
    else:
        video = video_or_images

    batch, channels, frames = video.shape[:3]

    assert divisible_by(
        frames - int(video_contains_first_frame), self.time_downsample_factor
    ), f"number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}"

    # encoder

    x = self.encode(video, cond=cond, video_contains_first_frame=video_contains_first_frame)

    # lookup free quantization (FSQ instead when self.use_fsq is set)

    if self.use_fsq:
        quantized, codes = self.quantizers(x)

        aux_losses = self.zero
        quantizer_loss_breakdown = None
    else:
        (quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True)

    if return_codes and not return_recon:
        return codes

    # decoder
    # ... (listing truncated here; the method continues to line 1751)
```
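The input handling at the top of `forward` can be mirrored on plain shape tuples. The helper below, `check_forward_inputs`, is hypothetical (it is not part of CogVideo); it is a minimal sketch assuming `divisible_by(n, d)` means `n % d == 0`, and it raises `ValueError` where the original uses `assert`, because assertions are skipped under `python -O`.

```python
from typing import Tuple


def check_forward_inputs(
    shape: Tuple[int, ...],
    image_size: int,
    time_downsample_factor: int,
    video_contains_first_frame: bool = True,
    return_loss: bool = False,
    return_codes: bool = False,
    return_discr_loss: bool = False,
) -> Tuple[Tuple[int, ...], bool]:
    """Hypothetical helper mirroring forward's shape checks on a shape tuple."""
    # At most one of the three output modes may be requested; bools sum as ints,
    # exactly as in the original assert.
    if return_loss + return_codes + return_discr_loss > 1:
        raise ValueError("return_loss, return_codes and return_discr_loss are mutually exclusive")

    if len(shape) not in (4, 5):
        raise ValueError(f"expected a 4D image or 5D video batch, got {len(shape)} dims")

    if tuple(shape[-2:]) != (image_size, image_size):
        raise ValueError(f"spatial size {tuple(shape[-2:])} != ({image_size}, {image_size})")

    if len(shape) == 4:
        # (b, c, h, w) -> (b, c, 1, h, w): an image is a one-frame video whose
        # single frame is always treated as the first frame.
        b, c, h, w = shape
        shape = (b, c, 1, h, w)
        video_contains_first_frame = True

    frames = shape[2]
    # The first frame is excluded before checking temporal divisibility.
    remaining = frames - int(video_contains_first_frame)
    if remaining % time_downsample_factor != 0:
        raise ValueError(
            f"{frames} frames minus the first frame ({remaining}) "
            f"must be divisible by {time_downsample_factor}"
        )
    return tuple(shape), video_contains_first_frame


if __name__ == "__main__":
    print(check_forward_inputs((2, 3, 17, 64, 64), image_size=64, time_downsample_factor=4))
```

With `time_downsample_factor=4`, a 17-frame clip passes (16 frames remain after the first), while a 16-frame clip passes only with `video_contains_first_frame=False`. A 4D image batch always passes the temporal check, since the forced flag leaves `0` frames to divide.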

Callers (1)

tokenize (method) · 0.95
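The caller `tokenize` presumably relies on the early `return codes` exit in `forward` (an assumption; its body is not shown here). When `self.use_fsq` is false, the source comment labels the quantization step lookup-free quantization (LFQ), in which each latent dimension snaps to a sign and the token index is the sign pattern read as a binary number, so no codebook lookup is needed. The pure-Python sketch below illustrates that idea for one latent vector; `lfq_quantize` and `lfq_dequantize` are hypothetical names, and the bit order and the mapping of `0` to `-1` are assumptions, not details taken from CogVideo's quantizer.

```python
from typing import List, Sequence, Tuple


def lfq_quantize(latent: Sequence[float]) -> Tuple[List[float], int]:
    """Sign-quantize one latent vector and derive its integer token (sketch)."""
    # Each dimension becomes +1.0 or -1.0; zero is sent to -1.0 here (assumption).
    quantized = [1.0 if v > 0 else -1.0 for v in latent]
    # Read the sign pattern as a binary number, first dimension = most significant bit.
    code = 0
    for v in latent:
        code = (code << 1) | int(v > 0)
    return quantized, code


def lfq_dequantize(code: int, dim: int) -> List[float]:
    """Recover the +/-1 vector from a token by reading its bits back out."""
    bits = [(code >> (dim - 1 - i)) & 1 for i in range(dim)]
    return [1.0 if b else -1.0 for b in bits]


if __name__ == "__main__":
    q, c = lfq_quantize([0.3, -1.2, 0.0, 2.5])
    print(q, c, lfq_dequantize(c, 4))
```

A `d`-dimensional latent therefore yields a vocabulary of `2 ** d` tokens; the 4-dimensional example maps to token `9` (binary `1001`).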

Calls (10)

encode (method) · 0.95
decode (method) · 0.95
default (function) · 0.70
divisible_by (function) · 0.70
exists (function) · 0.70
pick_video_frame (function) · 0.70
hinge_discr_loss (function) · 0.70
gradient_penalty (function) · 0.70
grad_layer_wrt_loss (function) · 0.70
hinge_gen_loss (function) · 0.70
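`hinge_discr_loss` and `hinge_gen_loss` are called in the part of `forward` not shown here. Their names suggest the standard hinge GAN objectives, sketched below in pure Python on lists of discriminator logits. This assumes, without confirmation from the source, that the CogVideo helpers compute the same formulas; the `_sketch` function names are hypothetical. `gradient_penalty` and `grad_layer_wrt_loss` are omitted because they need autograd.

```python
from typing import Sequence


def relu(x: float) -> float:
    """max(x, 0) for a single float."""
    return x if x > 0 else 0.0


def hinge_discr_loss_sketch(fake_logits: Sequence[float], real_logits: Sequence[float]) -> float:
    """Standard hinge discriminator loss: real logits above +1, fake logits below -1."""
    if len(fake_logits) != len(real_logits) or not fake_logits:
        raise ValueError("expected equal, non-zero numbers of fake and real logits")
    terms = [relu(1.0 + f) + relu(1.0 - r) for f, r in zip(fake_logits, real_logits)]
    return sum(terms) / len(terms)


def hinge_gen_loss_sketch(fake_logits: Sequence[float]) -> float:
    """Standard hinge generator loss: raise the discriminator's logits on reconstructions."""
    return -sum(fake_logits) / len(fake_logits)


if __name__ == "__main__":
    print(hinge_discr_loss_sketch([-2.0, 0.5], [1.5, 0.0]))
    print(hinge_gen_loss_sketch([-2.0, 0.5]))
```

The discriminator loss reaches zero once every real logit is at least `+1` and every fake logit at most `-1`, so well-separated samples stop contributing gradient; the generator term has no such margin.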

Tested by

no test coverage detected