Container mainly used for lists of PyTorch tensors. Extends Python lists with PyTorch functionality.
| 37 | |
| 38 | |
| 39 | class TensorList(list): |
| 40 | """Container mainly used for lists of torch tensors. Extends lists with pytorch functionality.""" |
| 41 | |
| 42 | def __init__(self, list_of_tensors = None): |
| 43 | if list_of_tensors is None: |
| 44 | list_of_tensors = list() |
| 45 | super(TensorList, self).__init__(list_of_tensors) |
| 46 | |
| 47 | def __deepcopy__(self, memodict={}): |
| 48 | return TensorList(copy.deepcopy(list(self), memodict)) |
| 49 | |
| 50 | def __getitem__(self, item): |
| 51 | if isinstance(item, int): |
| 52 | return super(TensorList, self).__getitem__(item) |
| 53 | elif isinstance(item, (tuple, list)): |
| 54 | return TensorList([super(TensorList, self).__getitem__(i) for i in item]) |
| 55 | else: |
| 56 | return TensorList(super(TensorList, self).__getitem__(item)) |
| 57 | |
| 58 | def __add__(self, other): |
| 59 | if TensorList._iterable(other): |
| 60 | return TensorList([e1 + e2 for e1, e2 in zip(self, other)]) |
| 61 | return TensorList([e + other for e in self]) |
| 62 | |
| 63 | def __radd__(self, other): |
| 64 | if TensorList._iterable(other): |
| 65 | return TensorList([e2 + e1 for e1, e2 in zip(self, other)]) |
| 66 | return TensorList([other + e for e in self]) |
| 67 | |
| 68 | def __iadd__(self, other): |
| 69 | if TensorList._iterable(other): |
| 70 | for i, e2 in enumerate(other): |
| 71 | self[i] += e2 |
| 72 | else: |
| 73 | for i in range(len(self)): |
| 74 | self[i] += other |
| 75 | return self |
| 76 | |
| 77 | def __sub__(self, other): |
| 78 | if TensorList._iterable(other): |
| 79 | return TensorList([e1 - e2 for e1, e2 in zip(self, other)]) |
| 80 | return TensorList([e - other for e in self]) |
| 81 | |
| 82 | def __rsub__(self, other): |
| 83 | if TensorList._iterable(other): |
| 84 | return TensorList([e2 - e1 for e1, e2 in zip(self, other)]) |
| 85 | return TensorList([other - e for e in self]) |
| 86 | |
| 87 | def __isub__(self, other): |
| 88 | if TensorList._iterable(other): |
| 89 | for i, e2 in enumerate(other): |
| 90 | self[i] -= e2 |
| 91 | else: |
| 92 | for i in range(len(self)): |
| 93 | self[i] -= other |
| 94 | return self |
| 95 | |
| 96 | def __mul__(self, other): |
no outgoing calls
no test coverage detected