| 94 | return {**metrics} |
| 95 | |
def update(
    self,
    feats_rst: Tensor,
    lengths_rst: List[int],
) -> None:
    """Embed one batch of generated motions and store it for later scoring.

    The batch is reordered longest-first before being passed to
    ``self.get_motion_embeddings``. The resulting embeddings are then put
    back into the caller's original sample order before being stored.

    Args:
        feats_rst: Batch of motion features, one sample per index along
            dim 0 (presumably padded ``(N, T, D)`` features -- TODO confirm).
        lengths_rst: Valid length of each sample in ``feats_rst``; one entry
            per sample.

    Side effects:
        * ``self.count`` grows by ``sum(lengths_rst)``.
        * ``self.count_seq`` grows by ``len(lengths_rst)``.
        * One tensor is appended to ``self.mm_motion_embeddings``. It holds
          the embeddings in the original sample order, with a new leading
          axis of size 1.
    """
    self.count += sum(lengths_rst)
    self.count_seq += len(lengths_rst)

    # Longest-first order. NOTE(review): presumably required by the motion
    # encoder (e.g. packed variable-length sequences); confirm in
    # get_motion_embeddings. ``.copy()`` materialises the reversed view, whose
    # stride is negative; this is likely to avoid torch rejecting
    # negatively-strided NumPy arrays.
    align_idx = np.argsort(lengths_rst)[::-1].copy()
    sorted_feats = feats_rst[align_idx]
    sorted_lengths = np.array(lengths_rst)[align_idx]
    sorted_embeddings = self.get_motion_embeddings(sorted_feats, sorted_lengths)

    # Undo the sort. Row i of ``sorted_embeddings`` belongs to sample
    # ``align_idx[i]``. Gathering with the inverse permutation (the argsort
    # of a permutation) therefore restores the caller's sample order in one
    # vectorised step.
    restore_idx = np.argsort(align_idx)
    mm_motion_embeddings = sorted_embeddings[restore_idx].unsqueeze(0)

    self.mm_motion_embeddings.append(mm_motion_embeddings)
| 118 | |
| 119 | def get_motion_embeddings(self, feats: Tensor, lengths: List[int]): |
| 120 | m_lens = torch.tensor(lengths) |