
Function _optimal_path

numpy/_core/einsumfunc.py:150–222

Computes all possible pair contractions, sieves the results based on ``memory_limit`` and returns the lowest cost path. This algorithm scales factorially with respect to the number of elements in the list ``input_sets``.
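To make the factorial scaling concrete, the sketch below counts how many complete pairwise contraction orders an exhaustive search would enumerate for `n` inputs when nothing is pruned by the memory limit. The function name `count_pairwise_paths` is hypothetical and not part of NumPy; the count simply multiplies the number of pairs available at each step.

```python
from math import comb


def count_pairwise_paths(n):
    """Upper bound on candidate paths for n inputs (hypothetical helper).

    At each step one pair is picked out of the k operands that remain,
    and the contraction result replaces that pair, leaving k - 1 operands.
    """
    total = 1
    for k in range(n, 1, -1):
        total *= comb(k, 2)  # number of distinct pairs among k operands
    return total


if __name__ == "__main__":
    for n in (3, 4, 5, 10):
        print(n, count_pairwise_paths(n))
```

The product equals n!(n-1)!/2^(n-1), which is why this kind of search is only practical for a small number of operands.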

(input_sets, output_set, idx_dict, memory_limit)
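As an illustration of what the first two arguments look like, the sketch below turns an explicit einsum subscript string into a list of index sets and an output set. The helper `subscripts_to_sets` is hypothetical and assumes an explicit `->` with single-letter indices and no ellipsis; it is not how NumPy parses subscripts internally.

```python
def subscripts_to_sets(subscripts):
    """Split 'abd,ac,bdc->' into ([{'a','b','d'}, ...], set()).

    Hypothetical helper: assumes an explicit '->' and no ellipsis.
    """
    lhs, rhs = subscripts.split("->")
    input_sets = [set(term) for term in lhs.split(",")]
    output_set = set(rhs)  # empty when the result is a scalar
    return input_sets, output_set


if __name__ == "__main__":
    isets, oset = subscripts_to_sets("abd,ac,bdc->")
    print(isets, oset)
```

The result matches the `isets` and `oset` values used in the docstring example of `_optimal_path`.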

Source from the content-addressed store, hash-verified

def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    full_results = [(0, [], input_sets)]
    for iteration in range(len(input_sets) - 1):
        iter_results = []

        # Compute all unique pairs
        for curr in full_results:
            cost, positions, remaining = curr
            for con in itertools.combinations(
                range(len(input_sets) - iteration), 2
            ):

                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit
                new_size = _compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + _flop_count(
                    idx_contract, idx_removed, len(con), idx_dict
                )
                new_pos = positions + [con]
                iter_results.append((total_cost, new_pos, new_input_sets))
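The listing above stops partway through the function, so the following is a self-contained sketch of the same exhaustive-search idea rather than a copy of the NumPy code. The names `contract_pair`, `size_of`, `flop_estimate` and `exhaustive_path` are hypothetical stand-ins for the helpers the listing calls, and both the cost model and the fallback used when every pair exceeds the memory limit are assumptions made for illustration.

```python
import itertools
from math import prod


def contract_pair(pair, remaining, output_set):
    """Hypothetical stand-in: contract the operands at positions `pair`."""
    picked = [remaining[i] for i in pair]
    rest = [s for i, s in enumerate(remaining) if i not in pair]
    idx_contract = set().union(*picked)
    # An index survives if the output or any other operand still needs it.
    keep = set(output_set).union(*rest)
    new_result = idx_contract & keep
    idx_removed = idx_contract - new_result
    return new_result, rest + [new_result], idx_removed, idx_contract


def size_of(indices, sizes):
    """Number of elements in an array carrying these indices."""
    return prod(sizes[i] for i in indices)


def flop_estimate(idx_contract, has_inner_sum, num_terms, sizes):
    """Assumed cost model: multiplies per term, plus an add when summing."""
    op_factor = max(1, num_terms - 1) + (1 if has_inner_sum else 0)
    return size_of(idx_contract, sizes) * op_factor


def exhaustive_path(input_sets, output_set, sizes, memory_limit):
    n = len(input_sets)
    candidates = [(0, [], list(input_sets))]  # (cost, path, remaining)
    for step in range(n - 1):
        next_candidates = []
        for cost, path, remaining in candidates:
            for pair in itertools.combinations(range(n - step), 2):
                result, new_remaining, removed, contract = contract_pair(
                    pair, remaining, output_set
                )
                # Sieve: drop intermediates larger than the memory limit.
                if size_of(result, sizes) > memory_limit:
                    continue
                new_cost = cost + flop_estimate(
                    contract, bool(removed), len(pair), sizes
                )
                next_candidates.append(
                    (new_cost, path + [pair], new_remaining)
                )
        if not next_candidates:
            # Assumed fallback: take the cheapest partial path and contract
            # everything that remains in a single step.
            best = min(candidates, key=lambda c: c[0])[1]
            return best + [tuple(range(n - step))]
        candidates = next_candidates
    return min(candidates, key=lambda c: c[0])[1]


if __name__ == "__main__":
    isets = [set("abd"), set("ac"), set("bdc")]
    idx_sizes = {"a": 1, "b": 2, "c": 3, "d": 4}
    print(exhaustive_path(isets, set(), idx_sizes, 5000))
```

On the docstring's inputs this sketch selects the same order as the documented example, `[(0, 2), (0, 1)]`, because contracting `abd` with `bdc` first leaves the smallest intermediate (`ac`, 3 elements) for the final step.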

Callers 1

einsum_path · Function · 0.85

Calls 4

_find_contraction · Function · 0.85
_compute_size_by_dict · Function · 0.85
_flop_count · Function · 0.85
min · Function · 0.70

Tested by

no test coverage detected
