(
// Type `model` as `any` here to avoid circular dependency w/
// training.ts.
// tslint:disable-next-line:no-any
model: any, dataset: Dataset<T>|LazyIterator<T>,
args: ModelEvaluateDatasetArgs)
| 531 | } |
| 532 | |
| 533 | export async function evaluateDataset<T>( |
| 534 | // Type `model` as `any` here to avoid circular dependency w/ |
| 535 | // training.ts. |
| 536 | // tslint:disable-next-line:no-any |
| 537 | model: any, dataset: Dataset<T>|LazyIterator<T>, |
| 538 | args: ModelEvaluateDatasetArgs): Promise<tfc.Scalar|tfc.Scalar[]> { |
| 539 | args = args || {}; |
| 540 | const hasBatches = args.batches != null; |
| 541 | const f = model.testFunction; |
| 542 | let outs: tfc.Scalar[] = []; |
| 543 | if (args.verbose > 0) { |
| 544 | throw new NotImplementedError('Verbose mode is not implemented yet.'); |
| 545 | } |
| 546 | |
| 547 | tfc.util.assert( |
| 548 | !hasBatches || (args.batches > 0 && Number.isInteger(args.batches)), |
| 549 | () => 'Test loop expects `batches` to be a positive integer, but ' + |
| 550 | `received ${JSON.stringify(args.batches)}`); |
| 551 | const dataIterator = isLazyIteratorObject(dataset) ? |
| 552 | dataset as LazyIterator<T>: |
| 553 | await (dataset as Dataset<T>).iterator(); |
| 554 | // Keeps track of number of examples used in this evaluation. |
| 555 | let numExamples = 0; |
| 556 | let batch = 0; |
| 557 | |
| 558 | while (hasBatches ? batch < args.batches : true) { |
| 559 | const iteratorOut = await dataIterator.next(); |
| 560 | outs = tfc.tidy(() => { |
| 561 | if (iteratorOut.value) { |
| 562 | // TODO(cais): Once real dataset is available, use |
| 563 | // `map(x => standardizeDataIteratorOutput(model, x).map(f)`. |
| 564 | const {xs, ys} = |
| 565 | standardizeDataIteratorOutput(model, iteratorOut.value); |
| 566 | const xsAndYs = xs.concat(ys); |
| 567 | const batchOuts = tfc.tidy(() => f(xsAndYs)); |
| 568 | tfc.dispose(xsAndYs); |
| 569 | |
| 570 | if (batch === 0) { |
| 571 | for (let i = 0; i < batchOuts.length; ++i) { |
| 572 | outs.push(scalar(0)); |
| 573 | } |
| 574 | } |
| 575 | |
| 576 | const batchSize = xsAndYs[0].shape[0]; |
| 577 | for (let i = 0; i < batchOuts.length; ++i) { |
| 578 | const batchOut = batchOuts[i]; |
| 579 | const oldScalar = outs[i]; |
| 580 | outs[i] = |
| 581 | tfc.tidy(() => tfc.add(outs[i], tfc.mul(batchSize, batchOut))); |
| 582 | if (batch > 0) { |
| 583 | tfc.dispose(oldScalar); |
| 584 | } |
| 585 | } |
| 586 | tfc.dispose(batchOuts); |
| 587 | numExamples += batchSize; |
| 588 | |
| 589 | ++batch; |
| 590 | } |
no test coverage detected
searching dependent graphs…