(
labels: Tensor1D, predictions: Tensor1D,
numClasses?: number)
| 279 | * @doc {heading: 'Metrics', namespace: 'metrics'} |
| 280 | */ |
export async function perClassAccuracy(
    labels: Tensor1D, predictions: Tensor1D,
    numClasses?: number): Promise<Array<{accuracy: number, count: number}>> {
  // Returns one {count, accuracy} entry per class index in [0, numClasses):
  // `count` is the number of samples whose label is that class, and
  // `accuracy` is the fraction of those samples whose prediction equals the
  // label (0 when the class has no samples). When `numClasses` is omitted it
  // is inferred as max(max(labels), max(predictions)) + 1, or 0 when the
  // inputs are empty. Throws if the inputs are not equal-length 1D tensors,
  // if the class count is not a non-negative integer, or if a label is not an
  // integer in [0, numClasses).
  assert(labels.rank === 1, 'labels must be a 1D tensor');
  assert(predictions.rank === 1, 'predictions must be a 1D tensor');
  assert(
      labels.size === predictions.size,
      'labels and predictions must be the same length');

  let classCount: number;
  if (numClasses != null) {
    classCount = numClasses;
  } else if (labels.size === 0) {
    // No samples: there is nothing to reduce over, so report no classes
    // rather than calling max() on empty tensors.
    classCount = 0;
  } else {
    const maximumTensor =
        tidy(() => maximum(labels.max(), predictions.max()));
    try {
      const maximumArray = await maximumTensor.data();
      classCount = maximumArray[0] + 1;
    } finally {
      // Release the scalar even if reading it back fails.
      maximumTensor.dispose();
    }
  }

  // Array(n) throws an opaque RangeError for fractional or negative n; fail
  // early with a descriptive message instead (this also catches float labels
  // that produced a fractional inferred count).
  assert(
      Number.isInteger(classCount) && classCount >= 0,
      `numClasses must be a non-negative integer, got ${classCount}`);

  const [labelsArray, predsArray] =
      await Promise.all([labels.data(), predictions.data()]);

  // Per-class sample counts and per-class correct-prediction counts.
  const counts: number[] = new Array(classCount).fill(0);
  const correct: number[] = new Array(classCount).fill(0);

  for (let i = 0; i < labelsArray.length; i++) {
    const label = labelsArray[i];
    // A label that is not a valid slot index would either be silently dropped
    // (negative / fractional keys are not array indices) or append NaN
    // entries past the end (too large), corrupting the result.
    if (!Number.isInteger(label) || label < 0 || label >= classCount) {
      throw new Error(
          `labels must be integers in [0, ${classCount}), got ${label} ` +
          `at index ${i}`);
    }
    counts[label] += 1;
    if (label === predsArray[i]) {
      correct[label] += 1;
    }
  }

  return counts.map((count, c) => ({
    count,
    accuracy: count === 0 ? 0 : correct[c] / count,
  }));
}
no test coverage detected
searching dependent graphs…