(self,
cfg,
datamodule,
lm,
motion_vae,
codebook_size=512,
stage='vae',
debug=True,
condition='text',
task='t2m',
metrics_dict=['TM2TMetrics'],
**kwargs)
| 20 | """ |
| 21 | |
def __init__(self,
             cfg,
             datamodule,
             lm,
             motion_vae,
             codebook_size=512,
             stage='vae',
             debug=True,
             condition='text',
             task='t2m',
             metrics_dict=['TM2TMetrics'],
             **kwargs):
    """Build the motion tokenizer, the motion-language model and the losses.

    Args:
        cfg: Config object forwarded to each ``GPTLosses`` instance.
        datamodule: Data module. Its ``njoints`` sizes the losses and its
            ``feats2joints`` is exposed as ``self.feats2joints``.
        lm: Config passed to ``instantiate_from_config`` to build ``self.lm``.
        motion_vae: Config for the motion tokenizer, or ``None`` to skip
            building ``self.vae``.
        codebook_size: Length of the ``self.codeFrequency`` counter.
        stage: Training stage. Any stage name containing ``'lm'`` freezes
            the motion tokenizer. Also forwarded to ``GPTLosses``.
        debug, condition, task, metrics_dict: Not read directly in this
            method. They are captured by ``save_hyperparameters``
            (presumably consumed by the base class or other methods;
            confirm against callers).
        **kwargs: Not used in this method.

    Raises:
        ValueError: If ``stage`` contains ``'lm'`` but ``motion_vae`` is
            ``None``, i.e. there is no tokenizer to freeze.
    """
    # NOTE(review): hyperparameters and the datamodule are attached *before*
    # the base-class __init__ runs, presumably because the base __init__
    # reads them. Keep this ordering unless confirmed otherwise.
    self.save_hyperparameters(ignore='datamodule', logger=False)
    self.datamodule = datamodule
    super().__init__()

    # Instantiate motion tokenizer (optional).
    if motion_vae is not None:
        self.vae = instantiate_from_config(motion_vae)

    # Instantiate motion-language model.
    self.lm = instantiate_from_config(lm)

    # Freeze the motion tokenizer for lm training.
    if 'lm' in self.hparams.stage:
        if motion_vae is None:
            # Fail early with a clear message instead of an opaque
            # AttributeError on the missing `self.vae`.
            raise ValueError(
                f"stage={self.hparams.stage!r} requires a motion tokenizer, "
                "but motion_vae is None")
        # eval() switches every submodule (e.g. dropout/norm layers) to
        # inference mode; assigning `.training = False` only flipped the
        # flag on the top-level module. Note that a later `.train()` call on
        # the whole model would re-enable training mode for these submodules.
        self.vae.eval()
        self.vae.requires_grad_(False)

    # Instantiate the losses: one GPTLosses per split, held in a ModuleDict.
    self._losses = torch.nn.ModuleDict({
        split: GPTLosses(cfg, self.hparams.stage, self.datamodule.njoints)
        for split in ["losses_train", "losses_test", "losses_val"]
    })

    # Data transform (features -> joints), taken from the datamodule.
    self.feats2joints = datamodule.feats2joints

    # Codebook usage bookkeeping.
    self.codePred = []
    # NOTE(review): assuming the base class is an nn.Module, this is a plain
    # tensor attribute rather than a registered buffer, so it is not saved in
    # state_dict and does not follow `.to(device)`. Confirm this is intended.
    self.codeFrequency = torch.zeros((self.hparams.codebook_size, ))
| 65 | def forward(self, batch, task="t2m"): |
| 66 | texts = batch["text"] |
nothing calls this directly
no test coverage detected