MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / __init__

Method __init__

mGPT/models/mgpt.py:22–63  ·  view source on GitHub ↗
(self,
                 cfg,
                 datamodule,
                 lm,
                 motion_vae,
                 codebook_size=512,
                 stage='vae',
                 debug=True,
                 condition='text',
                 task='t2m',
                 metrics_dict=['TM2TMetrics'],
                 **kwargs)

Source from the content-addressed store, hash-verified

20 """
21
def __init__(self,
             cfg,
             datamodule,
             lm,
             motion_vae,
             codebook_size=512,
             stage='vae',
             debug=True,
             condition='text',
             task='t2m',
             metrics_dict=None,
             **kwargs):
    """Build the motion tokenizer, the motion-language model and the losses.

    Args:
        cfg: Global config object; passed to every ``GPTLosses`` instance.
        datamodule: Data module. Its ``njoints`` sizes the losses and its
            ``feats2joints`` is stored as ``self.feats2joints``.
        lm: Config for the motion-language model, built with
            ``instantiate_from_config``.
        motion_vae: Config for the motion tokenizer, built with
            ``instantiate_from_config``; ``None`` skips creating ``self.vae``.
        codebook_size: Length of the ``self.codeFrequency`` counter.
        stage: Training stage name. If it contains ``'lm'``, the motion
            tokenizer (when present) is frozen.
        debug: Recorded in ``self.hparams``; not read in this method.
        condition: Recorded in ``self.hparams``; not read in this method.
        task: Recorded in ``self.hparams``; not read in this method.
        metrics_dict: Metric names recorded in ``self.hparams``. ``None``
            (the default) means ``['TM2TMetrics']``.
        **kwargs: Extra keyword arguments; not read in this method.
    """
    # A fresh list per call instead of a shared mutable default. This must
    # run before save_hyperparameters, which records this frame's current
    # argument values.
    if metrics_dict is None:
        metrics_dict = ['TM2TMetrics']

    # NOTE(review): hparams and datamodule are set *before*
    # super().__init__(), presumably because a parent __init__ reads them --
    # confirm before reordering. If the parent __init__ resets the
    # hyperparameter-logging flag, logger=False may be overridden; verify
    # against the Lightning version in use.
    self.save_hyperparameters(ignore='datamodule', logger=False)
    self.datamodule = datamodule
    super().__init__()

    # Instantiate motion tokenizer (optional).
    if motion_vae is not None:
        self.vae = instantiate_from_config(motion_vae)

    # Instantiate motion-language model
    self.lm = instantiate_from_config(lm)

    # Freeze the motion tokenizer for lm training. Guarded on motion_vae so
    # an lm stage without a tokenizer does not fail on the missing self.vae.
    if 'lm' in self.hparams.stage and motion_vae is not None:
        # NOTE(review): this clears only the top-level ``training`` flag;
        # submodules keep their own mode (``self.vae.eval()`` would recurse).
        # Confirm which is intended.
        self.vae.training = False
        self.vae.requires_grad_(False)

    # One loss module per split; ModuleDict registers them as submodules.
    self._losses = torch.nn.ModuleDict({
        split: GPTLosses(cfg, self.hparams.stage, self.datamodule.njoints)
        for split in ("losses_train", "losses_test", "losses_val")
    })

    # Data transform taken from the data module (presumably feature ->
    # joint conversion; see datamodule.feats2joints).
    self.feats2joints = datamodule.feats2joints

    # Count codebook frequency; these are not filled in this method.
    # NOTE: codeFrequency is a plain tensor attribute, not a registered
    # buffer, so Module.to()/cuda() do not move it and it is not part of
    # state_dict.
    self.codePred = []
    self.codeFrequency = torch.zeros((self.hparams.codebook_size, ))
64
65 def forward(self, batch, task="t2m"):
66 texts = batch["text"]

Callers

nothing calls this directly

Calls 2

instantiate_from_config — Function · 0.90
GPTLosses — Class · 0.90

Tested by

no test coverage detected