Computes the full set of atoms that must be recomputed when a given set of atoms changes. This includes: (1) all downstream atoms that directly or transitively depend on the changed atoms, and (2) all upstream atoms that produced inputs consumed by affected atoms, such as producers of recomputed consumers.
(self, changed_atoms: set[str])
| 352 | logger.warning(f"[DAG] Skipping producer registration for unknown atom {atom_name=}") |
| 353 | |
def _get_affected_atoms(self, changed_atoms: set[str]) -> set[str]:
    """
    Computes the full set of atoms that must be recomputed when a given set of atoms changes.

    This includes:
    - All downstream atoms that directly or transitively depend on the changed atoms
    - All upstream atoms that produced inputs consumed by affected atoms, such as
      producers of recomputed consumers.

    The traversal ensures that if any atom is recomputed, all consumers of its output
    are also recomputed, and all producers of its inputs are re-included as well.

    This forward and backward closure ensures correct propagation in the DAG, especially
    for side effecting calls like `plot()` that mutate objects used by downstream atoms.

    Because both directions are followed from every visited atom, the result is every
    registered atom connected to a changed atom through dependency edges in either
    direction.

    Names that are not registered in `self.atoms` (an unknown entry in `changed_atoms`
    or a dangling dependency) are logged and left out of the result, since there is
    nothing to recompute for them; their registered consumers are still included.

    Args:
        changed_atoms (set[str]): Initial set of atoms known to have changed.

    Returns:
        set[str]: The full set of registered atom names that should be recomputed.
    """
    # Reverse index built once (dependency name -> names of atoms that consume it),
    # so the forward step does not rescan every atom for each visited node.
    consumers_of: dict[str, list[str]] = {}
    for consumer_name, atom in self.atoms.items():
        for dep in atom.dependencies:
            consumers_of.setdefault(dep, []).append(consumer_name)

    affected: set[str] = set()
    queue = list(changed_atoms)
    # `queued` guarantees each name enters the queue at most once, so a popped
    # name has never been processed before and needs no extra visited check.
    queued = set(queue)

    logger.info(f"[DAG] Starting recompute traversal {changed_atoms=}")

    while queue:
        atom_name = queue.pop()
        atom = self.atoms.get(atom_name)

        if atom is None:
            # Unregistered name: no producers to walk back to, but atoms that
            # list it as a dependency are still picked up by the forward step.
            logger.warning(f"[DAG] Skipping unknown atom in recompute traversal {atom_name=}")
            producers = ()
        else:
            affected.add(atom_name)
            # backward: re-run producers of recomputed consumers
            producers = atom.dependencies

        # forward (consumers of this atom) followed by backward (its producers)
        for neighbour in (*consumers_of.get(atom_name, ()), *producers):
            if neighbour not in queued:
                queued.add(neighbour)
                queue.append(neighbour)

    return affected
| 401 | |
| 402 | def _validate_dependencies(self): |
| 403 | """ |
no outgoing calls
no test coverage detected