MCPcopy
hub / github.com/TencentARC/Pixal3D / VarLenTensor

Class VarLenTensor

pixal3d/modules/sparse/basic.py:17–302  ·  view source on GitHub ↗

Sequential tensor with variable length. Args: feats (torch.Tensor): features of the varlen tensor. layout (List[slice]): layout of the varlen tensor — one slice into `feats` per batch element.

Source retrieved from the content-addressed store (hash-verified).

15
16
17class VarLenTensor:
18 """
19 Sequential tensor with variable length.
20
21 Args:
22 feats (torch.Tensor): Features of the varlen tensor.
23 layout (List[slice]): Layout of the varlen tensor for each batch
24 """
def __init__(self, feats: torch.Tensor, layout: List[slice]=None):
    """
    Wrap packed features together with their per-entry layout.

    Args:
        feats (torch.Tensor): Packed features; batch entries are selected
            along dim 0.
        layout (List[slice]): One slice into ``feats`` per batch entry. When
            omitted, all ``feats.shape[0]`` rows form a single entry.
    """
    if layout is None:
        # No layout given: the whole feature tensor is one batch entry.
        layout = [slice(0, feats.shape[0])]
    self.feats = feats
    self.layout = layout
    # Per-instance cache, starts empty (what is stored here is not visible in
    # this excerpt — presumably filled by other methods of the class).
    self._cache = {}
29
@staticmethod
def layout_from_seqlen(seqlen: list) -> List[slice]:
    """
    Build a layout of back-to-back slices from per-entry sequence lengths.

    Entry ``i`` of the result covers the half-open range
    ``[sum(seqlen[:i]), sum(seqlen[:i + 1]))``, so the slices are consecutive
    and start at 0.

    Args:
        seqlen (list): Length of each batch entry, in order.

    Returns:
        List[slice]: One slice per entry of ``seqlen``.
    """
    # offsets[k] is the running total of the first k lengths.
    offsets = [0]
    for length in seqlen:
        offsets.append(offsets[-1] + length)
    return [slice(lo, hi) for lo, hi in zip(offsets[:-1], offsets[1:])]
41
@staticmethod
def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor':
    """
    Create a VarLenTensor from a list of tensors.

    The tensors are concatenated along dim 0. Each input tensor becomes one
    batch entry, whose slice covers its rows in the concatenated ``feats``.

    Args:
        tensor_list (List[torch.Tensor]): Tensors to pack. They must agree on
            every dimension except dim 0 (a ``torch.cat`` requirement);
            ``torch.cat`` also raises if the list is empty.

    Returns:
        VarLenTensor: A tensor with ``len(tensor_list)`` batch entries.
    """
    feats = torch.cat(tensor_list, dim=0)
    # Reuse the shared lengths -> layout helper rather than re-deriving the
    # running offsets by hand.
    layout = VarLenTensor.layout_from_seqlen([t.shape[0] for t in tensor_list])
    return VarLenTensor(feats, layout)
54
def to_tensor_list(self) -> List[torch.Tensor]:
    """
    Split the packed features back into one tensor per batch entry.

    Returns:
        List[torch.Tensor]: ``self.feats[s]`` for each slice ``s`` in
        ``self.layout``, in layout order. Slicing with a ``slice`` object
        returns a view, so the results share storage with ``self.feats``.
    """
    return [self.feats[s] for s in self.layout]
63
def __len__(self) -> int:
    """Number of batch entries, i.e. the number of slices in ``self.layout``."""
    num_entries = len(self.layout)
    return num_entries
66
@property
def shape(self) -> torch.Size:
    """
    Logical shape of the varlen tensor.

    The leading dimension is the number of batch entries (``len(self.layout)``).
    The remaining dimensions are the trailing feature dimensions
    ``self.feats.shape[1:]``.
    """
    num_entries = len(self.layout)
    feature_dims = tuple(self.feats.shape[1:])
    return torch.Size((num_entries, *feature_dims))
70
def dim(self) -> int:
    """Return the number of dimensions reported by ``self.shape``."""
    # Go through the ``shape`` attribute (not ``self.feats``) so that any
    # override of ``shape`` is respected.
    full_shape = self.shape
    return len(full_shape)
73
74 @property

Callers 4

from_tensor_list · Method · 0.85
replace · Method · 0.85
__getitem__ · Method · 0.85
varlen_cat · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected