MCPcopy
hub / github.com/tensorflow/tfjs / perClassAccuracy

Function perClassAccuracy

tfjs-vis/src/util/math.ts:281–326  ·  view source on GitHub ↗
(
    labels: Tensor1D, predictions: Tensor1D,
    numClasses?: number)

Source from the content-addressed store, hash-verified

279 * @doc {heading: 'Metrics', namespace: 'metrics'}
280 */
export async function perClassAccuracy(
    labels: Tensor1D, predictions: Tensor1D,
    numClasses?: number): Promise<Array<{accuracy: number, count: number}>> {
  // For every class index c in [0, numClasses) this reports:
  //   - `count`: how many entries of `labels` equal c;
  //   - `accuracy`: the fraction of those entries whose prediction is also c
  //     (0 when the class never occurs in `labels`).
  //
  // When `numClasses` is omitted it is inferred as
  // max(max(labels), max(predictions)) + 1, or 0 for empty inputs (yielding
  // an empty result).
  //
  // Throws if the inputs are not 1D, differ in length, or any label is not an
  // integer in [0, numClasses). Out-of-range labels used to silently grow the
  // result array with NaN/undefined entries or be dropped.
  assert(labels.rank === 1, 'labels must be a 1D tensor');
  assert(predictions.rank === 1, 'predictions must be a 1D tensor');
  assert(
      labels.size === predictions.size,
      'labels and predictions must be the same length');

  let classCount: number;
  if (numClasses != null) {
    classCount = numClasses;
  } else if (labels.size === 0) {
    // max() of an empty tensor is not a usable class index (it previously led
    // to `Array(NaN)` throwing a RangeError); with no data there are no
    // classes to report.
    classCount = 0;
  } else {
    const maximumTensor =
        tidy(() => maximum(labels.max(), predictions.max()));
    try {
      const maximumArray = await maximumTensor.data();
      classCount = maximumArray[0] + 1;
    } finally {
      // Release the scalar even if reading it back fails.
      maximumTensor.dispose();
    }
  }

  const [labelsArray, predsArray] =
      await Promise.all([labels.data(), predictions.data()]);

  // Per-class number of occurrences in `labels`.
  const counts: number[] = new Array(classCount).fill(0);
  // Per-class number of correct predictions.
  const correct: number[] = new Array(classCount).fill(0);

  for (let i = 0; i < labelsArray.length; i++) {
    const label = labelsArray[i];
    // A label outside [0, classCount) would index past the arrays (growing
    // them) or become a non-index property (being lost); reject it instead.
    if (!(Number.isInteger(label) && label >= 0 && label < classCount)) {
      throw new Error(
          `label ${label} at index ${i} is not an integer in ` +
          `[0, ${classCount})`);
    }
    counts[label] += 1;
    if (label === predsArray[i]) {
      correct[label] += 1;
    }
  }

  return counts.map(
      (count, c) => ({count, accuracy: count === 0 ? 0 : correct[c] / count}));
}

Callers 1

math_test.ts · File · 0.90

Calls 8

assert · Function · 0.90
tidy · Function · 0.90
max · Method · 0.80
all · Method · 0.80
data · Method · 0.65
maximum · Function · 0.50
dispose · Method · 0.45
push · Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…