(
labels: Tensor, predictions: Tensor)
| 240 | * @doc {heading: 'Metrics', namespace: 'metrics'} |
| 241 | */ |
export async function accuracy(
    labels: Tensor, predictions: Tensor): Promise<number> {
  assertShapesMatch(
      labels.shape, predictions.shape, 'Error computing accuracy.');

  // Element-wise comparison of labels and predictions; the mean of this
  // result is the fraction of positions where they agree.
  const eq = labels.equal(predictions);
  // Track every intermediate tensor created so far so that all of them are
  // released even if a later step throws or the async read rejects.
  const intermediates: Tensor[] = [eq];
  try {
    const mean = eq.mean();
    intermediates.push(mean);

    // Asynchronously read back the single value of the mean.
    return (await mean.data())[0];
  } finally {
    dispose(intermediates);
  }
}
| 255 | |
| 256 | /** |
| 257 | * Computes per class accuracy between prediction and labels. Each value in |
no test coverage detected
searching dependent graphs…