| 74 | |
| 75 | |
def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
    """Pad (or truncate) each tensor along dim 0 to ``max_len`` and stack them.

    Args:
        input_ele: Tensors to combine. They must agree on every dimension
            except dim 0, otherwise ``torch.stack`` fails.
        max_len: Target size of dim 0 for every tensor. A tensor longer than
            ``max_len`` is truncated, because ``F.pad`` with a negative amount
            in ``"constant"`` mode narrows the tensor.

    Returns:
        A tensor of shape ``(len(input_ele), max_len, *trailing_dims)``, where
        the added positions are filled with ``0.0``.

    Raises:
        RuntimeError: From ``torch.stack`` when ``input_ele`` is empty.
    """
    out_list = torch.jit.annotate(List[torch.Tensor], [])
    for batch in input_ele:
        # F.pad's spec lists (before, after) pairs starting from the LAST dim,
        # so dim 0's "after" amount is the final entry. Building the spec from
        # batch.dim() keeps the padding on dim 0 for any rank. A fixed
        # 4-element spec would pad dim -2, which is wrong for rank >= 3.
        pad_spec = [0] * (2 * batch.dim() - 1)
        pad_spec.append(max_len - batch.size(0))
        out_list.append(F.pad(batch, pad_spec, "constant", 0.0))
    return torch.stack(out_list)
| 86 | |
| 87 | |
| 88 | def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01): |