| 21 | |
| 22 | |
| 23 | class TensorList(object): |
| 24 | |
| 25 | def __init__(self, tensors): |
| 26 | """ |
| 27 | tensors: a list of torch.Tensor objects. No need to have uniform shape. |
| 28 | """ |
| 29 | assert isinstance(tensors, (list, tuple)) |
| 30 | assert all(isinstance(u, torch.Tensor) for u in tensors) |
| 31 | assert len(set([u.ndim for u in tensors])) == 1 |
| 32 | assert len(set([u.dtype for u in tensors])) == 1 |
| 33 | assert len(set([u.device for u in tensors])) == 1 |
| 34 | self.tensors = tensors |
| 35 | |
| 36 | def to(self, *args, **kwargs): |
| 37 | return TensorList([u.to(*args, **kwargs) for u in self.tensors]) |
| 38 | |
| 39 | def size(self, dim): |
| 40 | assert dim == 0, 'only support get the 0th size' |
| 41 | return len(self.tensors) |
| 42 | |
| 43 | def pow(self, *args, **kwargs): |
| 44 | return TensorList([u.pow(*args, **kwargs) for u in self.tensors]) |
| 45 | |
| 46 | def squeeze(self, dim): |
| 47 | assert dim != 0 |
| 48 | if dim > 0: |
| 49 | dim -= 1 |
| 50 | return TensorList([u.squeeze(dim) for u in self.tensors]) |
| 51 | |
| 52 | def type(self, *args, **kwargs): |
| 53 | return TensorList([u.type(*args, **kwargs) for u in self.tensors]) |
| 54 | |
| 55 | def type_as(self, other): |
| 56 | assert isinstance(other, (torch.Tensor, TensorList)) |
| 57 | if isinstance(other, torch.Tensor): |
| 58 | return TensorList([u.type_as(other) for u in self.tensors]) |
| 59 | else: |
| 60 | return TensorList([u.type(other.dtype) for u in self.tensors]) |
| 61 | |
| 62 | @property |
| 63 | def dtype(self): |
| 64 | return self.tensors[0].dtype |
| 65 | |
| 66 | @property |
| 67 | def device(self): |
| 68 | return self.tensors[0].device |
| 69 | |
| 70 | @property |
| 71 | def ndim(self): |
| 72 | return 1 + self.tensors[0].ndim |
| 73 | |
| 74 | def __getitem__(self, index): |
| 75 | return self.tensors[index] |
| 76 | |
| 77 | def __len__(self): |
| 78 | return len(self.tensors) |
| 79 | |
| 80 | def __add__(self, other): |