Plot a performance metric vs. forecast horizon from cross validation. Cross validation produces a collection of out-of-sample model predictions that can be compared to actual values, at a range of different horizons (distance from the cutoff). This computes a specified performance metric for each prediction, and aggregated over a rolling window with horizon.
(
df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6), color='b',
point_color='gray'
)
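To make the `rolling_window` behaviour concrete, here is a minimal sketch in plain Python of per-prediction squared errors averaged over a trailing window of predictions sorted by horizon. It is not Prophet's implementation. `rolling_mse` and its arguments are hypothetical names, the window size is assumed to be `int(rolling_window * n)` with a minimum of 1, and each average is assumed to be reported at the horizon of the last prediction in its window.

```python
from datetime import timedelta


def rolling_mse(horizons, y_true, y_pred, rolling_window=0.1):
    """Average squared errors over a trailing window of predictions.

    Hypothetical helper for illustration only. Predictions are sorted by
    horizon, and each window holds a fixed share of all predictions.
    """
    if not 0 < rolling_window <= 1:
        raise ValueError("rolling_window must be in (0, 1] for this sketch")

    # One (horizon, squared error) row per cross-validation prediction.
    rows = sorted(
        ((h, (a - p) ** 2) for h, a, p in zip(horizons, y_true, y_pred)),
        key=lambda row: row[0],
    )

    # Assumed window size: a proportion of the rows, at least one row.
    w = max(1, int(rolling_window * len(rows)))

    out = []
    for i in range(w - 1, len(rows)):
        window = rows[i - w + 1 : i + 1]
        mean_se = sum(se for _, se in window) / w
        # Report the average at the horizon on the window's right edge.
        out.append((rows[i][0], mean_se))
    return out


# Two cross-validation folds, each predicting 1 to 5 days past its cutoff.
horizons = [timedelta(days=d) for d in (1, 2, 3, 4, 5, 1, 2, 3, 4, 5)]
y_true = [10] * 10
y_pred = [9, 8, 7, 6, 5, 11, 12, 13, 14, 15]

result = rolling_mse(horizons, y_true, y_pred, rolling_window=0.2)
for horizon, mse in result:
    print(horizon, mse)
```

With 10 predictions and `rolling_window=0.2`, each average covers 2 predictions, so the curve is smoother than the raw squared errors but still rises with horizon.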
def plot_cross_validation_metric(
    df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6), color='b',
    point_color='gray'
):
    """Plot a performance metric vs. forecast horizon from cross validation.

    Cross validation produces a collection of out-of-sample model predictions
    that can be compared to actual values, at a range of different horizons
    (distance from the cutoff). This computes a specified performance metric
    for each prediction, and aggregated over a rolling window with horizon.

    This uses prophet.diagnostics.performance_metrics to compute the metrics.
    Valid values of metric are 'mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', and 'coverage'.

    rolling_window is the proportion of data included in the rolling window of
    aggregation. The default value of 0.1 means 10% of data are included in the
    aggregation for computing the metric.

    As a concrete example, if metric='mse', then this plot will show the
    squared error for each cross validation prediction, along with the MSE
    averaged over rolling windows of 10% of the data.

    Parameters
    ----------
    df_cv: The output from prophet.diagnostics.cross_validation.
    metric: Metric name, one of ['mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', 'coverage'].
    rolling_window: Proportion of data to use for rolling average of metric.
        In [0, 1]. Defaults to 0.1.
    ax: Optional matplotlib axis on which to plot. If not given, a new figure
        will be created.
    figsize: Optional tuple width, height in inches.
    color: Optional color for plot and error points, useful when plotting
        multiple model performances on one axis for comparison.

    Returns
    -------
    a matplotlib figure.
    """
    if ax is None:
        fig = plt.figure(facecolor='w', figsize=figsize)
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()
    # Get the metric at the level of individual predictions, and with the rolling window.
    df_none = performance_metrics(df_cv, metrics=[metric], rolling_window=-1)
    df_h = performance_metrics(df_cv, metrics=[metric], rolling_window=rolling_window)

    # Some work because matplotlib does not handle timedelta
    # Target ~10 ticks.
    tick_w = max(df_none['horizon'].astype('timedelta64[ns]')) / 10.
    # Find the largest time resolution that has <1 unit per bin.
    dts = ['D', 'h', 'm', 's', 'ms', 'us', 'ns']
    dt_names = [
        'days', 'hours', 'minutes', 'seconds', 'milliseconds', 'microseconds',
        'nanoseconds'
    ]
    dt_conversions = [
        24 * 60 * 60 * 10 ** 9,
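The listing stops before the unit-selection step, so the following is only a sketch of the idea stated in its comments: aim for about 10 ticks and pick the largest time unit that is smaller than one tick width, then express horizons in that unit. It uses only the standard library rather than pandas or NumPy. `UNITS`, `pick_axis_unit`, and `to_axis_units` are hypothetical names, and the strict "smaller than" comparison is an assumption, not something the source confirms.

```python
from datetime import timedelta

# (axis label, nanoseconds per unit), largest unit first.
UNITS = [
    ("days", 24 * 60 * 60 * 10**9),
    ("hours", 60 * 60 * 10**9),
    ("minutes", 60 * 10**9),
    ("seconds", 10**9),
    ("milliseconds", 10**6),
    ("microseconds", 10**3),
    ("nanoseconds", 1),
]


def _total_ns(delta):
    """Convert a timedelta to whole nanoseconds (timedelta stores microseconds)."""
    return (delta // timedelta(microseconds=1)) * 1000


def pick_axis_unit(max_horizon, target_ticks=10):
    """Return (label, ns_per_unit) for the largest unit below one tick width."""
    tick_ns = _total_ns(max_horizon) / target_ticks
    for label, unit_ns in UNITS:
        if unit_ns < tick_ns:  # assumed comparison: strictly smaller
            return label, unit_ns
    # Fall back to the finest unit for very short horizons.
    return UNITS[-1]


def to_axis_units(horizon, unit_ns):
    """Express a horizon as a plain float in the chosen unit."""
    return _total_ns(horizon) / unit_ns


label, unit_ns = pick_axis_unit(timedelta(days=2))
print(label, to_axis_units(timedelta(days=2), unit_ns))
```

Converting horizons to plain floats in a single unit gives the plotting library ordinary numbers for the x-axis, and the label can go into the axis title.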