hub / github.com/linkedin/greykite / calc_pred_err

Function calc_pred_err

greykite/common/evaluation.py:359–405

Calculates the basic error measures in `~greykite.common.evaluation.EvaluationMetricEnum` and returns them in a dictionary.

(y_true, y_pred)

Source from the content-addressed store, hash-verified

def calc_pred_err(y_true, y_pred):
    """Calculates the basic error measures in
    `~greykite.common.evaluation.EvaluationMetricEnum`
    and returns them in a dictionary.

    Parameters
    ----------
    y_true : `list` [`float`] or `numpy.array`
        Observed values.
    y_pred : `list` [`float`] or `numpy.array`
        Model predictions.

    Returns
    -------
    error : `dict` [`str`, `float` or None]
        Dictionary mapping
        `~greykite.common.evaluation.EvaluationMetricEnum`
        metric names to their values for the given ``y_true``
        and ``y_pred``.

        The dictionary is empty if there are no finite elements
        in ``y_true``.
    """
    y_true, y_pred = valid_elements_for_evaluation(
        reference_arrays=[y_true],
        arrays=[y_pred],
        reference_array_names="y_true",
        drop_leading_only=False,
        keep_inf=False)
    # The Silverkite Multistage model has NANs at the beginning
    # when predicting on the training data.
    # We only drop the leading NANs/infs from ``y_pred``,
    # since they are not supposed to appear in the middle.
    y_pred, y_true = valid_elements_for_evaluation(
        reference_arrays=[y_pred],
        arrays=[y_true],
        reference_array_names="y_pred",
        drop_leading_only=True,
        keep_inf=True)
    error = {}

    if len(y_true) > 0:
        for enum in EvaluationMetricEnum:
            metric_name = enum.get_metric_name()
            metric_func = enum.get_metric_func()
            error.update({metric_name: metric_func(y_true, y_pred)})
    return error
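
The two-pass filtering above can be illustrated with a minimal, self-contained sketch. It does not call greykite: `drop_invalid` is a hypothetical stand-in for `valid_elements_for_evaluation`, written from the parameter names and the inline comments, and `METRICS` is a hypothetical two-entry table standing in for `EvaluationMetricEnum`. The real helper and enum may differ in detail.

```python
import numpy as np


def drop_invalid(reference, other, drop_leading_only, keep_inf):
    """Hypothetical stand-in for ``valid_elements_for_evaluation``.

    Filters both arrays according to which elements of ``reference`` are valid.
    """
    reference = np.asarray(reference, dtype=float)
    other = np.asarray(other, dtype=float)
    # Valid means "not NaN"; infinities also count as invalid unless keep_inf.
    valid = ~np.isnan(reference) if keep_inf else np.isfinite(reference)
    if drop_leading_only:
        # Keep everything from the first valid element onward.
        start = int(np.argmax(valid)) if valid.any() else len(reference)
        keep = np.arange(len(reference)) >= start
    else:
        # Drop invalid elements wherever they occur.
        keep = valid
    return reference[keep], other[keep]


# Hypothetical metric table; the real function iterates over EvaluationMetricEnum.
METRICS = {
    "MAE": lambda y_true, y_pred: float(np.mean(np.abs(y_true - y_pred))),
    "RMSE": lambda y_true, y_pred: float(np.sqrt(np.mean((y_true - y_pred) ** 2))),
}


def calc_pred_err_sketch(y_true, y_pred):
    # Pass 1: drop every position where y_true is NaN or infinite.
    y_true, y_pred = drop_invalid(
        y_true, y_pred, drop_leading_only=False, keep_inf=False)
    # Pass 2: drop only the leading NaNs of y_pred.
    y_pred, y_true = drop_invalid(
        y_pred, y_true, drop_leading_only=True, keep_inf=True)
    error = {}
    if len(y_true) > 0:
        for name, func in METRICS.items():
            error[name] = func(y_true, y_pred)
    return error


example = calc_pred_err_sketch(
    y_true=[1.0, 2.0, np.nan, 4.0, 5.0],
    y_pred=[np.nan, 2.0, 3.0, 3.0, 7.0],
)
empty = calc_pred_err_sketch([np.nan, np.inf], [1.0, 2.0])
```

In this sketch, `example` is computed from the three surviving pairs (2, 2), (4, 3) and (5, 7), and `empty` is `{}` because no finite element of `y_true` remains, matching the behavior described in the docstring.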

Calls 3

get_metric_name · Method · 0.45
get_metric_func · Method · 0.45