(batch)
| 28 | return canvas |
| 29 | |
def humanml3d_collate(batch):
    """Collate HumanML3D-style samples into a dict of batched fields.

    ``None`` entries in *batch* are dropped first. Each remaining sample is
    indexed positionally; the fields this function reads are:

    * ``b[0]`` caption text (may be ``None``)
    * ``b[1]`` motion features, converted to a float tensor and batched with
      ``collate_tensors``
    * ``b[2]`` motion length
    * ``b[3]`` word embeddings, ``b[4]`` POS one-hot features,
      ``b[5]`` text length, ``b[6]`` tokens -- read only when ``b[5]`` of the
      first sample is not ``None`` ("evaluation" samples)
    * ``b[7]`` all captions -- read only when ``b[0]`` is not ``None``
    * ``b[8]`` task -- read only when samples have exactly 9 fields

    Which optional fields are emitted is decided from the first remaining
    sample; all samples are assumed to share its layout.

    Args:
        batch: sequence of per-sample tuples, or ``None`` for samples the
            dataset chose to drop.

    Returns:
        dict that always has ``"motion"`` and ``"length"``, plus
        ``"text"``/``"all_captions"`` when captions are present,
        ``"text"``/``"word_embs"``/``"pos_ohot"``/``"text_len"``/``"tokens"``
        for evaluation samples, and ``"tasks"`` for 9-field samples.

    Raises:
        ValueError: if *batch* contains no non-``None`` samples.
    """
    samples = [b for b in batch if b is not None]
    if not samples:
        # Previously this surfaced as an opaque IndexError on samples[0].
        raise ValueError("humanml3d_collate: batch contains no non-None samples")

    eval_mode = samples[0][5] is not None

    # Sort by text length, longest first (presumably for a packed-sequence
    # text encoder used during evaluation -- confirm). list.sort is stable.
    if eval_mode:
        samples.sort(key=lambda x: x[5], reverse=True)

    # Layout checks read the first sample after sorting, as before.
    first = samples[0]
    has_text = first[0] is not None

    # Motion only
    adapted_batch = {
        "motion": collate_tensors(
            [torch.tensor(b[1]).float() for b in samples]),
        "length": [b[2] for b in samples],
    }

    # Text: emitted for captioned samples and, even if captions are None,
    # for evaluation samples. Assigned once here, before "all_captions", to
    # keep the original key order.
    if has_text or eval_mode:
        adapted_batch["text"] = [b[0] for b in samples]
    if has_text:
        adapted_batch["all_captions"] = [b[7] for b in samples]

    # Evaluation related
    if eval_mode:
        adapted_batch.update({
            "word_embs": collate_tensors(
                [torch.tensor(b[3]).float() for b in samples]),
            "pos_ohot": collate_tensors(
                [torch.tensor(b[4]).float() for b in samples]),
            "text_len": collate_tensors(
                [torch.tensor(b[5]) for b in samples]),
            "tokens": [b[6] for b in samples],
        })

    # Tasks
    if len(first) == 9:
        adapted_batch["tasks"] = [b[8] for b in samples]

    return adapted_batch
| 72 | |
| 73 | |
| 74 | def load_pkl(path, description=None, progressBar=False): |
nothing calls this directly
no test coverage detected