| 95 | * @param input The tensor whose values are summarized. |
| 96 | */ |
export async function tensorStats(input: Tensor): Promise<HistogramStats> {
  // TODO. Benchmark this and consider having one of the *stats functions
  // delegate to the other.

  // Run the reductions on the tensor's backend. tidy() disposes the
  // intermediate tensors (`zero` and the `equal` mask), but the three tensors
  // it returns survive it and must be disposed explicitly below.
  const [min, max, numZeros] = tidy(() => {
    const zero = scalar(0, input.dtype);
    return [input.min(), input.max(), input.equal(zero).sum()];
  });

  try {
    const [tensorVal, minVal, maxVal, numZerosVal] = await Promise.all(
        [input.data(), min.data(), max.data(), numZeros.data()]);

    // We currently need to count NaNs (and Infs) on the CPU from the
    // downloaded values.
    const numVals = tensorVal.length;
    let numNans = 0;
    let numInfs = 0;
    for (let i = 0; i < numVals; i++) {
      const curr = tensorVal[i];
      if (isNaN(curr)) {
        numNans += 1;
      } else if (!isFinite(curr)) {
        // Reached only for non-NaN values, so NaNs are not double counted
        // as Infs.
        numInfs += 1;
      }
    }

    // On GPU the min and max won't be accurate if all values are NaN, so
    // report NaN explicitly in that case. Note that this condition also
    // holds for an empty input (numVals === 0).
    const allNaN = numNans === numVals;

    return {
      numVals,
      numZeros: numZerosVal[0],
      numNans,
      min: allNaN ? NaN : minVal[0],
      max: allNaN ? NaN : maxVal[0],
      numInfs,
    };
  } finally {
    // Release the reduction results. This runs even if downloading the data
    // failed, so the tensors are never leaked.
    min.dispose();
    max.dispose();
    numZeros.dispose();
  }
}
| 147 | |
| 148 | /** |
| 149 | * Computes a confusion matrix from predictions and labels. Each value in |