MCPcopy
hub / github.com/tensorflow/tfjs / handleMetrics

Method handleMetrics

tfjs-layers/src/engine/training.ts:712–775  ·  view source on GitHub ↗
(metrics: Array<string|LossOrMetricFn>)

Source retrieved from the content-addressed store (hash-verified).

710
711 // TODO(cais): Add optional arg `weights` to the following function.
712 const handleMetrics = (metrics: Array<string|LossOrMetricFn>) => {
713 const metricNamePrefix = '';
714 let metricName: string;
715 let accFn: LossOrMetricFn;
716 let weightedMetricFn: LossOrMetricFn;
717 // TODO(cais): Use 'weights_' for weighted metrics.
718
719 for (const metric of metrics) {
720 if (typeof metric === 'string' &&
721 ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !==
722 -1) {
723 const outputShape = this.internalOutputShapes[i];
724
725 if (outputShape[outputShape.length - 1] === 1 ||
726 this.lossFunctions[i] === losses.binaryCrossentropy) {
727 // case: binary accuracy/crossentropy.
728 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
729 accFn = Metrics.binaryAccuracy;
730 } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
731 accFn = Metrics.binaryCrossentropy;
732 }
733 } else if (
734 this.lossFunctions[i] ===
735 losses.sparseCategoricalCrossentropy) {
736 // case: categorical accuracy / crossentropy with sparse
737 // targets.
738 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
739 accFn = Metrics.sparseCategoricalAccuracy;
740 } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
741 accFn = Metrics.sparseCategoricalCrossentropy;
742 }
743 } else {
744 // case: categorical accuracy / crossentropy.
745 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
746 accFn = Metrics.categoricalAccuracy;
747 } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
748 accFn = Metrics.categoricalCrossentropy;
749 }
750 }
751 let suffix: string;
752 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
753 suffix = 'acc';
754 } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
755 suffix = 'ce';
756 }
757 // TODO(cais): Add weighting actually.
758 weightedMetricFn = accFn;
759 metricName = metricNamePrefix + suffix;
760 } else {
761 const metricFn = Metrics.get(metric);
762 // TODO(cais): Add weighting actually.
763 weightedMetricFn = metricFn;
764 metricName =
765 metricNamePrefix + Metrics.getLossOrMetricName(metric);
766 }
767
768 // TODO(cais): Add weighting and masking to metricResult.
769 let metricResult: LossOrMetricFn;

Callers

nothing calls this directly

Calls 2

nameScopeFunction · 0.90
getMethod · 0.45

Tested by

no test coverage detected