```python
def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]

    return pad
```
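The fill-then-copy strategy used by `pad_list` does not depend on PyTorch. The following is a minimal sketch of it in NumPy, assuming NumPy arrays in place of tensors. `pad_list_np` is a hypothetical name and is not part of the original code. The sketch allocates one `(B, Tmax, *)` buffer filled with `pad_value`, using the dtype of the first element as `xs[0].new(...)` does. It then copies each sequence into the front of its row, so trailing feature dimensions are carried through unchanged.

```python
import numpy as np


def pad_list_np(xs, pad_value):
    """Hypothetical NumPy port of pad_list (not part of the original code).

    Pads a list of arrays [(T_1, *), ..., (T_B, *)] into one (B, Tmax, *) array.
    """
    n_batch = len(xs)
    max_len = max(x.shape[0] for x in xs)
    # Allocate the whole batch at once, pre-filled with pad_value; the trailing
    # dims (the `*` part) and the dtype are taken from the first element.
    pad = np.full((n_batch, max_len, *xs[0].shape[1:]), pad_value, dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        # Copy the real frames into the front of row i; the rest stays pad_value.
        pad[i, : x.shape[0]] = x
    return pad


# 1-D sequences, mirroring the docstring example: result has shape (3, 4).
padded = pad_list_np([np.ones(4), np.ones(2), np.ones(1)], 0)

# Sequences with a trailing feature dim of 2: result has shape (2, 3, 2),
# and frames 1..2 of the second sequence are filled with -1.0.
padded_feats = pad_list_np([np.ones((3, 2)), np.ones((1, 2))], -1.0)
```

Filling once and then doing B slice assignments avoids building per-sequence padding. In PyTorch itself, `torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=pad_value)` should produce the same `(B, Tmax, *)` layout for a list of tensors that share a dtype and trailing shape.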
def pad_list_all_dim(xs, pad_value):