
Function _find_contraction

github.com/numpy/numpy · numpy/_core/einsumfunc.py:90–147

Finds the contraction for a given set of input and output sets. Given the positions of the operands to contract, it returns the indices of the resulting intermediate, the updated list of operand index sets, the indices summed away, and every index involved in the contraction.

(positions, input_sets, output_set)
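Per the docstring, the arguments are plain Python sets of index labels: one set per operand on the left-hand side of the einsum subscript, plus one set for the right-hand side. A minimal sketch of building them from an explicit subscript string follows. `parse_subscripts` is a hypothetical helper, not NumPy's own parser, and it only handles the explicit `->` form (no `...` broadcasting).

```python
def parse_subscripts(subscripts):
    """Hypothetical helper: split an explicit einsum string such as
    'ab,bc->ac' into (input_sets, output_set). No '...' support."""
    lhs, rhs = subscripts.replace(' ', '').split('->')
    # One set of index labels per operand on the left-hand side.
    input_sets = [set(term) for term in lhs.split(',')]
    # The right-hand side becomes a single set of output labels.
    output_set = set(rhs)
    return input_sets, output_set


input_sets, output_set = parse_subscripts('ab,bc->ac')
print([sorted(s) for s in input_sets])  # [['a', 'b'], ['b', 'c']]
print(sorted(output_set))               # ['a', 'c']
```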


```python
def _find_contraction(positions, input_sets, output_set):
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction
    remaining : list
        List of sets that have not been contracted, the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """

    idx_contract = set()
    idx_remain = output_set.copy()
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value
        else:
            remaining.append(value)
            idx_remain |= value

    new_result = idx_remain & idx_contract
    idx_removed = (idx_contract - new_result)
    remaining.append(new_result)

    return (new_result, remaining, idx_removed, idx_contract)
```
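Because the new set is appended to the end of `remaining`, a caller can feed `remaining` straight back in as the next step's `input_sets`, with positions always referring to the current operand list. The sketch below walks a two-step path for `'ab,bc,cd->ad'` (a chain of three matrices) this way. `find_contraction` is a local copy of the function above, so the sketch runs without importing NumPy's private module, and `walk_path` is a hypothetical driver.

```python
def find_contraction(positions, input_sets, output_set):
    # Local copy of _find_contraction's body, for a self-contained sketch.
    idx_contract = set()
    idx_remain = output_set.copy()
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value
        else:
            remaining.append(value)
            idx_remain |= value
    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)
    return new_result, remaining, idx_removed, idx_contract


def walk_path(input_sets, output_set, path):
    """Hypothetical driver: apply each step of `path` (position tuples) in order."""
    steps = []
    for positions in path:
        # `remaining` becomes the operand list for the next step.
        new_result, input_sets, idx_removed, _ = find_contraction(
            positions, input_sets, output_set)
        # Record merged positions, the intermediate's indices, and summed indices.
        steps.append((sorted(positions), sorted(new_result), sorted(idx_removed)))
    return steps, input_sets


steps, final = walk_path([set('ab'), set('bc'), set('cd')], set('ad'), [(0, 1), (0, 1)])
for step in steps:
    print(step)
# ([0, 1], ['a', 'c'], ['b'])
# ([0, 1], ['a', 'd'], ['c'])
print([sorted(s) for s in final])  # [['a', 'd']]
```

After the first step the operand list is `[{'c', 'd'}, {'a', 'c'}]`, so the second `(0, 1)` contracts the untouched `cd` operand with the freshly appended intermediate.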

Callers 4

_optimal_path (function) · 0.85
_greedy_path (function) · 0.85
einsum_path (function) · 0.85
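The listed callers search for a contraction order, evaluating candidate contractions along the way. As a hedged illustration of how the return values can drive such a search, here is a simplified greedy heuristic that always contracts the pair whose intermediate (`new_result`) has the fewest elements. It is hypothetical and is not NumPy's `_greedy_path`, whose cost model differs; `find_contraction` is a local copy of the function above, and `sizes` is an assumed mapping from index label to dimension length.

```python
from itertools import combinations
from math import prod


def find_contraction(positions, input_sets, output_set):
    # Local copy of _find_contraction's body, for a self-contained sketch.
    idx_contract = set()
    idx_remain = output_set.copy()
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value
        else:
            remaining.append(value)
            idx_remain |= value
    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)
    return new_result, remaining, idx_removed, idx_contract


def greedy_path(input_sets, output_set, sizes):
    """Hypothetical greedy ordering: repeatedly contract the pair whose
    intermediate result has the smallest number of elements."""
    path = []
    while len(input_sets) > 1:
        best = None
        for pair in combinations(range(len(input_sets)), 2):
            new_result, remaining, _, _ = find_contraction(pair, input_sets, output_set)
            cost = prod(sizes[i] for i in new_result)  # elements in the intermediate
            if best is None or cost < best[0]:
                best = (cost, pair, remaining)
        _, pair, input_sets = best  # continue from the updated operand list
        path.append(pair)
    return path


sizes = {'a': 10, 'b': 2, 'c': 100, 'd': 3}
print(greedy_path([set('ab'), set('bc'), set('cd')], set('ad'), sizes))
# [(1, 2), (0, 1)]
```

Contracting `bc` with `cd` first yields a 2 x 3 intermediate, whereas starting with `ab` and `bc` would produce a 10 x 100 one.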

Calls 1

copy (method) · 0.45
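The one recorded call, `copy`, corresponds to `output_set.copy()` at the top of the body. Since `|=` updates a set in place, copying first keeps the caller's `output_set` intact when the same set is passed to many candidate contractions. The two hypothetical helpers below isolate that difference.

```python
def extend_aliased(output_set, extra):
    # Hypothetical helper: binds a second name to the caller's set.
    idx_remain = output_set
    idx_remain |= extra  # in-place union: the caller's set changes too
    return idx_remain


def extend_copied(output_set, extra):
    # Mirrors the `output_set.copy()` line in _find_contraction.
    idx_remain = output_set.copy()
    idx_remain |= extra  # only the private copy changes
    return idx_remain


out = {'a', 'c'}
extend_copied(out, {'b'})
print(sorted(out))  # ['a', 'c']
extend_aliased(out, {'b'})
print(sorted(out))  # ['a', 'b', 'c']
```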

Tested by

no test coverage detected
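With no coverage detected, a minimal pytest-style sketch could start from the docstring examples plus a few invariants that follow from the body: `new_result` and `idx_removed` partition `idx_contract`, `remaining` has `len(input_sets) - len(positions) + 1` entries for distinct in-range positions, the new set lands last, and `output_set` is not mutated. `find_contraction` is a local copy here; a real test would presumably import `_find_contraction` from `numpy._core.einsumfunc` (an assumption based on the file path above).

```python
def find_contraction(positions, input_sets, output_set):
    # Local copy of _find_contraction's body, for a self-contained sketch.
    idx_contract = set()
    idx_remain = output_set.copy()
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value
        else:
            remaining.append(value)
            idx_remain |= value
    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)
    return new_result, remaining, idx_removed, idx_contract


def test_dot_product():
    # First docstring example: 'ab,bc->ac'.
    result = find_contraction((0, 1), [set('ab'), set('bc')], set('ac'))
    assert result == ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})


def test_extra_terms():
    # Second docstring example: operand 1 is left untouched.
    result = find_contraction((0, 2), [set('abd'), set('ac'), set('bdc')], set('ac'))
    assert result == ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, set('abcd'))


def test_invariants_and_no_mutation():
    input_sets = [set('ab'), set('bc'), set('cd')]
    output_set = set('ad')
    new_result, remaining, idx_removed, idx_contract = find_contraction(
        (0, 1), input_sets, output_set)
    assert new_result | idx_removed == idx_contract   # together they cover all indices
    assert not (new_result & idx_removed)             # and they are disjoint
    assert len(remaining) == len(input_sets) - 2 + 1  # two operands merged into one
    assert remaining[-1] == new_result                # new set appended last
    assert output_set == set('ad')                    # caller's set untouched
    assert input_sets == [set('ab'), set('bc'), set('cd')]
```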
