github.com/OpenMotionLab/MotionGPT

Method update

mGPT/metrics/mm.py:96–117

```python
from typing import List

import numpy as np
import torch
from torch import Tensor


def update(
    self,
    feats_rst: Tensor,
    lengths_rst: List[int],
):
    # Running totals: summed sequence lengths and number of sequences.
    self.count += sum(lengths_rst)
    self.count_seq += len(lengths_rst)

    # Sort the batch by descending length; .copy() turns the reversed
    # (negative-stride) view into a contiguous index array.
    align_idx = np.argsort(lengths_rst)[::-1].copy()
    feats_rst = feats_rst[align_idx]
    lengths_rst = np.array(lengths_rst)[align_idx]
    recmotion_embeddings = self.get_motion_embeddings(
        feats_rst, lengths_rst)

    # Scatter the embeddings back to the original batch order:
    # sorted row i belongs at position align_idx[i].
    cache = [0] * len(lengths_rst)
    for i in range(len(lengths_rst)):
        cache[align_idx[i]] = recmotion_embeddings[i:i + 1]

    # Concatenate along the batch dim, add a leading dim of size 1,
    # and store all mm motion embeddings.
    mm_motion_embeddings = torch.cat(cache, dim=0).unsqueeze(0)
    self.mm_motion_embeddings.append(mm_motion_embeddings)


def get_motion_embeddings(self, feats: Tensor, lengths: List[int]):
    m_lens = torch.tensor(lengths)
    ...  # remainder of this method is not shown in the excerpt
```
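The reorder/restore step in `update` can be isolated as below. This is a sketch, not the repository's code: NumPy arrays stand in for tensors, and `fake_embed` is a hypothetical placeholder for `get_motion_embeddings` that averages each sequence's valid frames and rejects batches not in descending-length order. The excerpt does not say why `update` sorts; a plausible reason (an assumption) is that the motion encoder uses `torch.nn.utils.rnn.pack_padded_sequence`, whose default `enforce_sorted=True` requires descending lengths. The sketch also checks that the scatter loop over `align_idx` matches indexing with the inverse permutation `np.argsort(align_idx)`.

```python
from typing import List

import numpy as np


def fake_embed(feats: np.ndarray, lengths: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for get_motion_embeddings (one row per sequence)."""
    # Require the descending-length order that update() establishes.
    if any(lengths[i] < lengths[i + 1] for i in range(len(lengths) - 1)):
        raise ValueError("batch must be sorted by descending length")
    # "Embedding" = mean of the valid (unpadded) frames of each sequence.
    return np.stack([f[:n].mean(axis=0) for f, n in zip(feats, lengths)])


def embed_in_original_order(feats: np.ndarray, lengths: List[int]) -> np.ndarray:
    # Permutation sorting the batch by descending length, as in update().
    align_idx = np.argsort(lengths)[::-1].copy()
    sorted_emb = fake_embed(feats[align_idx], np.asarray(lengths)[align_idx])

    # Same scatter loop as update(): sorted row i belongs at align_idx[i].
    restored = np.empty_like(sorted_emb)
    for i in range(len(lengths)):
        restored[align_idx[i]] = sorted_emb[i]

    # Vectorized equivalent: argsort of a permutation is its inverse.
    assert np.array_equal(restored, sorted_emb[np.argsort(align_idx)])
    return restored


feats = np.arange(24, dtype=float).reshape(3, 4, 2)  # (batch, frames, dim)
emb = embed_in_original_order(feats, [2, 4, 3])
print(emb.tolist())  # [[1.0, 2.0], [11.0, 12.0], [18.0, 19.0]]
```

Here the loop writes whole rows, whereas `update` stores `(1, ...)` slices and joins them with `torch.cat`; the resulting batch order is the same either way.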

Callers: 14 (12 shown)

main (function) · 0.45
load_motion (function) · 0.45
add_text (function) · 0.45
add_audio (function) · 0.45
add_file (function) · 0.45
app.py (file) · 0.45
get_module_config (function) · 0.45
getCheckpointCallback (function) · 0.45
humanml3d_collate (function) · 0.45
get_sample_set (method) · 0.45
allsplit_step (method) · 0.45
on_train_epoch_end (method) · 0.45

Calls: 1

get_motion_embeddings (method) · 0.95

Tested by: 1

main (function) · 0.36