| 710 | |
| 711 | // TODO(cais): Add optional arg `weights` to the following function. |
| 712 | const handleMetrics = (metrics: Array<string|LossOrMetricFn>) => { |
| 713 | const metricNamePrefix = ''; |
| 714 | let metricName: string; |
| 715 | let accFn: LossOrMetricFn; |
| 716 | let weightedMetricFn: LossOrMetricFn; |
| 717 | // TODO(cais): Use 'weights_' for weighted metrics. |
| 718 | |
| 719 | for (const metric of metrics) { |
| 720 | if (typeof metric === 'string' && |
| 721 | ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !== |
| 722 | -1) { |
| 723 | const outputShape = this.internalOutputShapes[i]; |
| 724 | |
| 725 | if (outputShape[outputShape.length - 1] === 1 || |
| 726 | this.lossFunctions[i] === losses.binaryCrossentropy) { |
| 727 | // case: binary accuracy/crossentropy. |
| 728 | if (['accuracy', 'acc'].indexOf(metric) !== -1) { |
| 729 | accFn = Metrics.binaryAccuracy; |
| 730 | } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { |
| 731 | accFn = Metrics.binaryCrossentropy; |
| 732 | } |
| 733 | } else if ( |
| 734 | this.lossFunctions[i] === |
| 735 | losses.sparseCategoricalCrossentropy) { |
| 736 | // case: categorical accuracy / crossentropy with sparse |
| 737 | // targets. |
| 738 | if (['accuracy', 'acc'].indexOf(metric) !== -1) { |
| 739 | accFn = Metrics.sparseCategoricalAccuracy; |
| 740 | } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { |
| 741 | accFn = Metrics.sparseCategoricalCrossentropy; |
| 742 | } |
| 743 | } else { |
| 744 | // case: categorical accuracy / crossentropy. |
| 745 | if (['accuracy', 'acc'].indexOf(metric) !== -1) { |
| 746 | accFn = Metrics.categoricalAccuracy; |
| 747 | } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { |
| 748 | accFn = Metrics.categoricalCrossentropy; |
| 749 | } |
| 750 | } |
| 751 | let suffix: string; |
| 752 | if (['accuracy', 'acc'].indexOf(metric) !== -1) { |
| 753 | suffix = 'acc'; |
| 754 | } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { |
| 755 | suffix = 'ce'; |
| 756 | } |
| 757 | // TODO(cais): Add weighting actually. |
| 758 | weightedMetricFn = accFn; |
| 759 | metricName = metricNamePrefix + suffix; |
| 760 | } else { |
| 761 | const metricFn = Metrics.get(metric); |
| 762 | // TODO(cais): Add weighting actually. |
| 763 | weightedMetricFn = metricFn; |
| 764 | metricName = |
| 765 | metricNamePrefix + Metrics.getLossOrMetricName(metric); |
| 766 | } |
| 767 | |
| 768 | // TODO(cais): Add weighting and masking to metricResult. |
| 769 | let metricResult: LossOrMetricFn; |