Retrieve the inverse functions of a pipeline or an estimator.
(estimator, terminal=True)
| 621 | |
| 622 | |
def _get_inverse_funcs(estimator, terminal=True):
    """Retrieve the inverse functions of a pipeline or an estimator.

    Objects exposing a ``steps`` attribute (pipelines) are walked
    recursively, so nested pipelines are flattened into one ordered list.

    Parameters
    ----------
    estimator : object
        A pipeline-like object with ``steps`` (a sequence of
        ``(name, estimator)`` pairs) or a single estimator.
    terminal : bool
        If True (default), ``estimator`` is treated as the outermost object
        and the flattened list is post-processed: when the last step has no
        ``inverse_transform`` and every preceding step does, the last step is
        dropped and the preceding inverse functions are returned. In every
        other case an empty list is returned, with a warning naming any
        non-invertible intermediate steps. If False, the raw flattened list
        is returned, holding ``False`` for each step without
        ``inverse_transform``.

    Returns
    -------
    inverse_func : list
        The ``inverse_transform`` bound methods in pipeline order (plus
        ``False`` placeholders when ``terminal=False``).
    """
    # Keep each step's class name paired with its inverse function so the
    # warning below reports the right estimator even for nested pipelines
    # (a nested pipeline contributes several entries, not one).
    named_funcs = _flatten_inverse_funcs(estimator)
    inverse_func = [func for _, func in named_funcs]
    if not terminal:
        return inverse_func
    if not inverse_func:
        # e.g. a pipeline-like object with no steps: nothing to invert.
        return list()

    # If terminal node, check that the last estimator is a classifier,
    # and remove it from the transformers.
    last_is_estimator = inverse_func[-1] is False
    logger.debug(f" Last estimator is an estimator: {last_is_estimator}")
    bad_names = [name for name, func in named_funcs[:-1] if func is False]
    if last_is_estimator and not bad_names:
        # keep all inverse transformation and remove last estimation
        logger.debug(" Removing inverse transformation from inverse list.")
        return inverse_func[:-1]
    if bad_names:
        bad = ", ".join(bad_names)
        warn(
            f"Cannot inverse transform non-invertible "
            f"estimator{_pl(bad_names)}: {bad}."
        )
    return list()


def _flatten_inverse_funcs(estimator):
    """Return flattened ``(class_name, inverse_transform or False)`` pairs."""
    if hasattr(estimator, "steps"):
        # if pipeline, retrieve all steps by nesting
        pairs = list()
        for _, est in estimator.steps:
            pairs.extend(_flatten_inverse_funcs(est))
        return pairs
    name = estimator.__class__.__name__
    if hasattr(estimator, "inverse_transform"):
        return [(name, estimator.inverse_transform)]
    # No inverse available: ``False`` marks it as non-invertible.
    return [(name, False)]
| 662 | |
| 663 | |
| 664 | def _get_inverse_funcs_before_step(estimator, step_name): |