```python
import itertools


def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorially with the number of elements in ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the left-hand side (the input terms)
        of the einsum subscript
    output_set : set
        Set that represents the right-hand side (the output) of the overall
        einsum subscript
    idx_dict : dict
        Dictionary mapping each index to its dimension size
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    full_results = [(0, [], input_sets)]
    for iteration in range(len(input_sets) - 1):
        iter_results = []

        # Compute all unique pairs
        for curr in full_results:
            cost, positions, remaining = curr
            # Each pass merges two operands into one, so one fewer remains
            for con in itertools.combinations(
                range(len(input_sets) - iteration), 2
            ):

                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit
                new_size = _compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + _flop_count(
                    idx_contract, idx_removed, len(con), idx_dict
                )
                new_pos = positions + [con]
                iter_results.append((total_cost, new_pos, new_input_sets))

        # ... (remainder of the function not shown in this excerpt)
```
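The listing calls three helpers that are not shown here (`_find_contraction`, `_compute_size_by_dict` and `_flop_count`), and it stops before the step that picks the final path. The following is a self-contained sketch of the same exhaustive search, handy for experimenting with `memory_limit`. The names `size_by_dict`, `find_contraction`, `flop_count` and `optimal_path_sketch` are hypothetical, and their bodies (in particular the flop-count model) are assumptions inferred from the call sites, not the module's actual code. The tail is also assumed: the sketch returns the lowest-cost surviving path and, if the memory limit prunes every pair at some step, contracts all remaining operands in one final step.

```python
import itertools
from math import prod


def size_by_dict(indices, idx_dict):
    """Hypothetical helper: number of elements spanned by `indices`."""
    return prod(idx_dict[i] for i in indices)


def find_contraction(positions, input_sets, output_set):
    """Hypothetical helper: contract the operands at `positions`.

    Returns (new_result, new_input_sets, idx_removed, idx_contract).
    """
    idx_contract = set()
    idx_remain = set(output_set)
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value
        else:
            remaining.append(value)
            idx_remain |= value
    # Keep only indices still needed by another operand or by the output.
    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)
    return new_result, remaining, idx_removed, idx_contract


def flop_count(idx_contract, idx_removed, num_terms, idx_dict):
    """Assumed cost model: loop size times a per-element operation factor."""
    op_factor = max(1, num_terms - 1)
    if idx_removed:  # summing out indices costs one extra addition
        op_factor += 1
    return size_by_dict(idx_contract, idx_dict) * op_factor


def optimal_path_sketch(input_sets, output_set, idx_dict, memory_limit):
    full_results = [(0, [], input_sets)]
    for iteration in range(len(input_sets) - 1):
        iter_results = []
        for cost, positions, remaining in full_results:
            n_left = len(input_sets) - iteration
            for con in itertools.combinations(range(n_left), 2):
                new_result, new_sets, removed, contract = find_contraction(
                    con, remaining, output_set
                )
                # Sieve: drop any step whose intermediate exceeds the limit.
                if size_by_dict(new_result, idx_dict) > memory_limit:
                    continue
                total = cost + flop_count(contract, removed, len(con), idx_dict)
                iter_results.append((total, positions + [con], new_sets))
        if not iter_results:
            # Assumed fallback: contract everything left in a single step.
            best = min(full_results, key=lambda r: r[0])[1]
            return best + [tuple(range(len(input_sets) - iteration))]
        full_results = iter_results
    # Assumed final step: pick the cheapest complete path.
    return min(full_results, key=lambda r: r[0])[1]
```

On the docstring example the sketch returns `[(0, 2), (0, 1)]` with `memory_limit=5000`, matching the expected output in the docstring; with `memory_limit=2` every pairwise intermediate is too large, so it falls back to `[(0, 1, 2)]`. The search never discards a partial path for being expensive, only for exceeding the memory limit, which is what makes it exhaustive.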
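The docstring's "scales factorially" remark can be made concrete. Ignoring the memory sieve, each pass of the outer loop expands every surviving partial path by all C(k, 2) pairs of the k = `len(input_sets) - iteration` operands still present. The number of complete paths for n operands is therefore the product of C(k, 2) for k = 2..n, which equals n! (n - 1)! / 2**(n - 1). This closed form is derived here from the loop structure of the listing, not stated in the source. The hypothetical functions `count_pair_paths` and `closed_form` below count it both ways.

```python
from math import comb, factorial


def count_pair_paths(n):
    """Count ordered sequences of pairwise contractions for n operands."""
    total = 1
    for k in range(n, 1, -1):
        # With k operands left there are C(k, 2) pairs to choose from,
        # and contracting one pair leaves k - 1 operands.
        total *= comb(k, 2)
    return total


def closed_form(n):
    """Closed form of the same product: n! * (n - 1)! / 2**(n - 1)."""
    return factorial(n) * factorial(n - 1) // 2 ** (n - 1)


for n in range(2, 8):
    print(n, count_pair_paths(n), closed_form(n))
```

For 2 through 7 operands this gives 1, 3, 18, 180, 2700 and 56700 paths. That growth suggests why an exhaustive search like this is only practical for a handful of operands; a tight `memory_limit` can prune branches, but with a generous limit the worst case remains.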