(*codes)
| 21 | from .vqvae_tokenizer import sqrt_int |
| 22 | |
def concat_codes(*codes):
    """Concatenate several code sequences into a single sequence.

    The type of the result is chosen from the arguments:

    * If any argument is a ``torch.Tensor``, every argument is brought onto
      the device of the *last* tensor argument and the pieces are joined with
      ``torch.cat``. Tensors are moved with ``.to(device)``. Everything else
      (lists, numpy arrays) is converted with ``torch.tensor``.
    * Otherwise, if any argument is an ``np.ndarray``, every argument is
      converted with ``np.array`` and joined with ``np.concatenate`` along
      axis 0.
    * Otherwise the arguments are treated as plain iterables and their
      elements are collected, in order, into a new ``list``.

    Args:
        *codes: Sequences to join, in order. Lists, tuples, numpy arrays and
            torch tensors may be mixed.

    Returns:
        A ``torch.Tensor``, ``np.ndarray`` or ``list`` as described above.
        Calling with no arguments returns ``[]``.
    """
    tensor_devices = [code.device for code in codes if isinstance(code, torch.Tensor)]
    if tensor_devices:
        # The last tensor seen decides where the result lives; every piece is
        # moved there first so mixed-device inputs can still be concatenated.
        device = tensor_devices[-1]
        # Calling torch.tensor() on an existing tensor always copies and emits
        # a UserWarning; .to() is the supported way to relocate a tensor.
        pieces = [
            code.to(device) if isinstance(code, torch.Tensor)
            else torch.tensor(code, device=device)
            for code in codes
        ]
        return torch.cat(pieces)

    if any(isinstance(code, np.ndarray) for code in codes):
        return np.concatenate([np.array(code) for code in codes], axis=0)

    # Plain sequences: extend one list in place instead of rebuilding it with
    # ``ret + code`` on every step (quadratic). Unlike ``list + tuple``, this
    # also accepts tuples and other iterables.
    merged = []
    for code in codes:
        merged.extend(code)
    return merged
| 51 | |
| 52 | def TextCodeTemplate(text, code): |
| 53 | tokenizer = get_tokenizer() |
no outgoing calls
no test coverage detected