Sequential tensor with variable length. Args: feats (torch.Tensor): Features of the varlen tensor. layout (List[slice]): Layout of the varlen tensor, one slice per batch element.
| 15 | |
| 16 | |
class VarLenTensor:
    """
    Sequential tensor with variable length.

    The rows of every batch element are stored back-to-back in ``feats``
    along dim 0; ``layout[i]`` is the slice of rows owned by element ``i``.

    Args:
        feats (torch.Tensor): Features of the varlen tensor.
        layout (List[slice]): Layout of the varlen tensor for each batch
            element. When omitted, a single slice covering every row of
            ``feats`` is used, i.e. the whole tensor is one batch element.
    """
    def __init__(self, feats: torch.Tensor, layout: List[slice]=None):
        if layout is None:
            # Default: treat all rows as belonging to one batch element.
            layout = [slice(0, feats.shape[0])]
        self.feats = feats
        self.layout = layout
        # NOTE(review): _cache is not read or written in the visible code;
        # presumably a per-instance memo used by other methods — confirm.
        self._cache = {}
| 29 | |
| 30 | @staticmethod |
| 31 | def layout_from_seqlen(seqlen: list) -> List[slice]: |
| 32 | """ |
| 33 | Create a layout from a tensor of sequence lengths. |
| 34 | """ |
| 35 | layout = [] |
| 36 | start = 0 |
| 37 | for l in seqlen: |
| 38 | layout.append(slice(start, start + l)) |
| 39 | start += l |
| 40 | return layout |
| 41 | |
| 42 | @staticmethod |
| 43 | def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': |
| 44 | """ |
| 45 | Create a VarLenTensor from a list of tensors. |
| 46 | """ |
| 47 | feats = torch.cat(tensor_list, dim=0) |
| 48 | layout = [] |
| 49 | start = 0 |
| 50 | for tensor in tensor_list: |
| 51 | layout.append(slice(start, start + tensor.shape[0])) |
| 52 | start += tensor.shape[0] |
| 53 | return VarLenTensor(feats, layout) |
| 54 | |
| 55 | def to_tensor_list(self) -> List[torch.Tensor]: |
| 56 | """ |
| 57 | Convert a VarLenTensor to a list of tensors. |
| 58 | """ |
| 59 | tensor_list = [] |
| 60 | for s in self.layout: |
| 61 | tensor_list.append(self.feats[s]) |
| 62 | return tensor_list |
| 63 | |
| 64 | def __len__(self) -> int: |
| 65 | return len(self.layout) |
| 66 | |
| 67 | @property |
| 68 | def shape(self) -> torch.Size: |
| 69 | return torch.Size([len(self.layout), *self.feats.shape[1:]]) |
| 70 | |
| 71 | def dim(self) -> int: |
| 72 | return len(self.shape) |
| 73 | |
| 74 | @property |
no outgoing calls
no test coverage detected