```python
    return path

def _parse_possible_contraction(
    positions, input_sets, output_set, idx_dict,
    memory_limit, path_cost, naive_cost
):
    """Compute the cost (removed size + flops) and resultant indices for
    performing the contraction specified by ``positions``.

    Returns ``None`` if the resulting intermediate would exceed
    ``memory_limit`` or if the accumulated cost would exceed ``naive_cost``.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensor.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The maximum allowed size of an intermediate tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    cost : (int, int)
        The sort key ``(-removed_size, flop_cost)``, where ``removed_size``
        is the combined size of the contracted tensors minus the size of
        the new intermediate.
    positions : tuple of int
        The locations of the proposed tensors to contract.
    new_input_sets : list of sets
        The resulting new list of indices if this proposed contraction
        is performed.

    """

    # Find the contraction
    contract = _find_contraction(positions, input_sets, output_set)
    idx_result, new_input_sets, idx_removed, idx_contract = contract

    # Sieve the results based on memory_limit
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Build sort tuple
    old_sizes = (
        _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
    )
    removed_size = sum(old_sizes) - new_size

    # NB: removed_size used to be just the size of any removed indices i.e.:
    # helpers.compute_size_by_dict(idx_removed, idx_dict)
    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
    sort = (-removed_size, cost)

    # Sieve based on total cost as well
    if (path_cost + cost) > naive_cost:
        return None

    # Keep this contraction as a candidate
    return [sort, positions, new_input_sets]
```
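Written out with $d_i$ = `idx_dict[i]`, $S_p$ = `input_sets[p]`, $R$ = `idx_result`, $C$ = `idx_contract`, $X$ = `idx_removed` and $n$ = `len(positions)`, the sort key looks roughly as follows. The first line restates the code above. The flop term is an assumption: it supposes the private `_flop_count` helper (not part of this excerpt) multiplies the size of all involved indices by an operation factor, as recent NumPy releases appear to do, so check it against your NumPy version. $[X \neq \varnothing]$ is 1 when at least one index is summed away and 0 otherwise.

$$
\begin{aligned}
\text{removed\_size} &= \sum_{p \in \text{positions}} \prod_{i \in S_p} d_i \;-\; \prod_{i \in R} d_i \\
\text{cost} &= \Big(\prod_{i \in C} d_i\Big)\big(\max(1,\, n-1) + [X \neq \varnothing]\big) \\
\text{sort} &= (-\text{removed\_size},\ \text{cost})
\end{aligned}
$$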
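A self-contained sketch that runs the same logic on the matrix chain `ij,jk,kl->il`. The helpers `compute_size_by_dict`, `find_contraction`, `flop_count` and `parse_possible_contraction` are hypothetical stand-ins written for this example, modeled on the private helpers the function calls; they are not NumPy's own code, and the dimension sizes are made up.

```python
# Hypothetical stand-ins for NumPy's private einsum helpers, modeled on the
# call sites in _parse_possible_contraction. Not NumPy's actual code.


def compute_size_by_dict(indices, idx_dict):
    """Product of the dimension sizes of ``indices``."""
    size = 1
    for i in indices:
        size *= idx_dict[i]
    return size


def find_contraction(positions, input_sets, output_set):
    """Split the operands into the contracted group and the untouched rest."""
    idx_contract = set()
    idx_remain = set(output_set)
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value
        else:
            remaining.append(value)
            idx_remain |= value
    # Indices still needed by the output or by an untouched operand survive.
    idx_result = idx_remain & idx_contract
    idx_removed = idx_contract - idx_result
    remaining.append(idx_result)
    return idx_result, remaining, idx_removed, idx_contract


def flop_count(idx_contract, inner, num_terms, idx_dict):
    """Assumed cost model: loop size times a per-element operation factor."""
    op_factor = max(1, num_terms - 1) + (1 if inner else 0)
    return compute_size_by_dict(idx_contract, idx_dict) * op_factor


def parse_possible_contraction(positions, input_sets, output_set, idx_dict,
                               memory_limit, path_cost, naive_cost):
    idx_result, new_input_sets, idx_removed, idx_contract = find_contraction(
        positions, input_sets, output_set)
    new_size = compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None  # intermediate too large
    removed_size = sum(compute_size_by_dict(input_sets[p], idx_dict)
                       for p in positions) - new_size
    cost = flop_count(idx_contract, idx_removed, len(positions), idx_dict)
    if path_cost + cost > naive_cost:
        return None  # already worse than not optimizing at all
    return [(-removed_size, cost), positions, new_input_sets]


# "ij,jk,kl->il": a three-matrix chain with hypothetical dimension sizes.
input_sets = [{"i", "j"}, {"j", "k"}, {"k", "l"}]
output_set = {"i", "l"}
idx_dict = {"i": 2, "j": 3, "k": 4, "l": 5}
naive_cost = flop_count({"i", "j", "k", "l"}, True, 3, idx_dict)  # 360

candidates = [
    parse_possible_contraction(pair, input_sets, output_set, idx_dict,
                               memory_limit=10**6, path_cost=0,
                               naive_cost=naive_cost)
    for pair in [(0, 1), (0, 2), (1, 2)]
]
for sort_key, pair, _ in candidates:
    print(pair, sort_key)
# (0, 1) (-10, 48)
# (0, 2) (94, 120)
# (1, 2) (-17, 120)

# A greedy search ranking candidates by this key keeps the smallest one.
best = min(candidates, key=lambda c: c[0])
print("best:", best[1])  # best: (1, 2)
```

Under these assumptions the key ranks `(1, 2)` first: contracting `jk` with `kl` frees 17 elements, more than the 10 freed by `(0, 1)`, even though `(0, 1)` needs fewer flops (48 vs 120). Flops only break ties in `removed_size`. The outer product `(0, 2)` grows memory by 94 elements and ranks last.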