MCP · copy
hub / github.com/yangchris11/samurai / TensorList

Class TensorList

lib/utils/tensor.py:39–217  ·  view source on GitHub ↗

A container mainly used for lists of PyTorch tensors; it extends Python's built-in `list` with PyTorch-style functionality such as element-wise arithmetic.

Source retrieved from the content-addressed store (hash-verified).

37
38
39class TensorList(list):
40 """Container mainly used for lists of torch tensors. Extends lists with pytorch functionality."""
41
42 def __init__(self, list_of_tensors = None):
43 if list_of_tensors is None:
44 list_of_tensors = list()
45 super(TensorList, self).__init__(list_of_tensors)
46
47 def __deepcopy__(self, memodict={}):
48 return TensorList(copy.deepcopy(list(self), memodict))
49
50 def __getitem__(self, item):
51 if isinstance(item, int):
52 return super(TensorList, self).__getitem__(item)
53 elif isinstance(item, (tuple, list)):
54 return TensorList([super(TensorList, self).__getitem__(i) for i in item])
55 else:
56 return TensorList(super(TensorList, self).__getitem__(item))
57
def __add__(self, other):
    """Return ``self + other`` as a new ``TensorList``.

    If ``TensorList._iterable(other)`` is true, the operands are added pairwise
    via ``zip``, so the result is as long as the shorter one. Otherwise
    ``other`` is added to every element.
    """
    if not TensorList._iterable(other):
        return TensorList([mine + other for mine in self])
    return TensorList([mine + theirs for mine, theirs in zip(self, other)])
62
def __radd__(self, other):
    """Return ``other + self`` as a new ``TensorList`` (reflected addition).

    Operand order is kept: each result element is ``other_item + self_item``
    (paired via ``zip`` when ``TensorList._iterable(other)`` is true), or
    ``other + self_item`` for a non-sequence ``other``.
    """
    if not TensorList._iterable(other):
        return TensorList([other + mine for mine in self])
    return TensorList([theirs + mine for mine, theirs in zip(self, other)])
67
def __iadd__(self, other):
    """Add ``other`` into this list in place and return ``self``.

    Each element is updated with ``+=`` and stored back at its position.
    If ``TensorList._iterable(other)`` is true, ``other[i]`` is added to
    ``self[i]`` for every index of ``other``. A longer ``other`` therefore
    raises IndexError, and a shorter one leaves the trailing elements untouched.
    Otherwise ``other`` is added to every element.
    """
    if not TensorList._iterable(other):
        for pos in range(len(self)):
            self[pos] += other
        return self
    for pos, theirs in enumerate(other):
        self[pos] += theirs
    return self
76
def __sub__(self, other):
    """Return ``self - other`` as a new ``TensorList``.

    If ``TensorList._iterable(other)`` is true, the operands are subtracted
    pairwise via ``zip``, so the result is as long as the shorter one.
    Otherwise ``other`` is subtracted from every element.
    """
    if not TensorList._iterable(other):
        return TensorList([mine - other for mine in self])
    return TensorList([mine - theirs for mine, theirs in zip(self, other)])
81
def __rsub__(self, other):
    """Return ``other - self`` as a new ``TensorList`` (reflected subtraction).

    Each result element is ``other_item - self_item`` (paired via ``zip`` when
    ``TensorList._iterable(other)`` is true), or ``other - self_item`` for a
    non-sequence ``other``.
    """
    if not TensorList._iterable(other):
        return TensorList([other - mine for mine in self])
    return TensorList([theirs - mine for mine, theirs in zip(self, other)])
86
def __isub__(self, other):
    """Subtract ``other`` from this list in place and return ``self``.

    Each element is updated with ``-=`` and stored back at its position.
    If ``TensorList._iterable(other)`` is true, ``other[i]`` is subtracted from
    ``self[i]`` for every index of ``other``. A longer ``other`` therefore
    raises IndexError, and a shorter one leaves the trailing elements untouched.
    Otherwise ``other`` is subtracted from every element.
    """
    if not TensorList._iterable(other):
        for pos in range(len(self)):
            self[pos] -= other
        return self
    for pos, theirs in enumerate(other):
        self[pos] -= theirs
    return self
95
96 def __mul__(self, other):

Callers 15

__init__Method · 0.90
ltr_collateFunction · 0.90
ltr_collate_stack1Function · 0.90
__deepcopy__Method · 0.85
__getitem__Method · 0.85
__add__Method · 0.85
__radd__Method · 0.85
__sub__Method · 0.85
__rsub__Method · 0.85
__mul__Method · 0.85
__rmul__Method · 0.85
__truediv__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected