```python
def _find_contraction(positions, input_sets, output_set):
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction. Membership is
        tested once per input term, so a reusable container such as a tuple
        is expected rather than a one-shot iterator.
    input_sets : list
        List of sets that represent the left-hand side of the einsum subscript
    output_set : set
        Set that represents the right-hand side of the overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction: indices of the contracted
        terms that are still needed by the output or by a remaining term
    remaining : list
        List of sets that have not been contracted; the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction (summed over in this step)
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """

    idx_contract = set()
    idx_remain = output_set.copy()
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            # Terms being contracted contribute all of their indices.
            idx_contract |= value
        else:
            # Untouched terms are carried over, and their indices must survive.
            remaining.append(value)
            idx_remain |= value

    # Keep only the contracted indices that the output or a remaining term
    # still needs; everything else is summed over in this step.
    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)

    return (new_result, remaining, idx_removed, idx_contract)
```
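The returned `idx_contraction` set is what makes a single step's work measurable: evaluating one pairwise contraction as plain nested loops needs one loop per index in that set, so the innermost-iteration count is the product of those dimension sizes. The sketch below assumes a hypothetical `sizes` mapping from index letter to dimension length; `naive_iterations` is illustrative only and is a simplified model, not the cost function NumPy's path optimizers use.

```python
from math import prod


def _find_contraction(positions, input_sets, output_set):
    # Same logic as the function above, repeated so this block is standalone.
    idx_contract = set()
    idx_remain = output_set.copy()
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value
        else:
            remaining.append(value)
            idx_remain |= value
    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)
    return new_result, remaining, idx_removed, idx_contract


def naive_iterations(positions, input_sets, output_set, sizes):
    """Hypothetical helper: innermost-loop count for one pairwise step."""
    _, _, _, idx_contraction = _find_contraction(positions, input_sets, output_set)
    # One nested loop per index touched by the step, so multiply their sizes.
    return prod(sizes[idx] for idx in idx_contraction)


sizes = {"a": 2, "b": 3, "c": 4, "d": 5}  # assumed dimension lengths
terms = [set("ab"), set("bc"), set("cd")]
print(naive_iterations((0, 1), terms, set("ad"), sizes))  # 24 (loops over a, b, c)
print(naive_iterations((1, 2), terms, set("ad"), sizes))  # 60 (loops over b, c, d)
```

Under these assumed sizes, contracting terms 0 and 1 of `"ab,bc,cd->ad"` first is cheaper than starting with terms 1 and 2, which is the kind of comparison a path search makes from `idx_contraction`.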