MCPcopy
hub / github.com/tensorflow/tfjs / predictLoop

Method predictLoop

tfjs-layers/src/engine/training.ts:1033–1076  ·  view source on GitHub ↗

Helper method to loop over some data in batches. Porting Note: Not using the functional approach in the Python equivalent due to the imperative backend. Porting Note: Does not support step mode currently. @param ins: input data. @param batchSize: integer batch size…

(ins: Tensor|Tensor[], batchSize = 32, verbose = false)

Source from the content-addressed store, hash-verified

1031 * `tf.Tensor` (if multiple outputs).
1032 */
private predictLoop(ins: Tensor|Tensor[], batchSize = 32, verbose = false):
    Tensor|Tensor[] {
  return tfc.tidy(() => {
    // Validates the inputs and yields the sample count; this must happen
    // before the verbose check so its errors take precedence.
    const numSamples = this.checkNumSamples(ins);
    if (verbose) {
      throw new NotImplementedError(
          'Verbose predictLoop() is not implemented yet.');
    }

    // Porting Note: tensors have no numpy-style sliced assignment
    // (e.g. x[1:3] = y), so every batch's outputs are collected per model
    // output and concatenated along axis 0 once all batches have run.
    const batchRanges = makeBatches(numSamples, batchSize);
    const collected: Tensor[][] = this.outputs.map(() => []);

    for (const [start, end] of batchRanges) {
      // Per-batch tidy: intermediates from this forward pass are released
      // when the batch finishes; only the returned output tensors survive.
      const batchResults = tfc.tidy(() => {
        // TODO(cais): Take care of the case of the last element is a flag
        // for training/test.
        const batchInputs = sliceArrays(ins, start, end);

        // Pair each model input with its slice of the data for execute().
        const feeds = Array.isArray(batchInputs) ?
            batchInputs.map((value, i) => ({key: this.inputs[i], value})) :
            [{key: this.inputs[0], value: batchInputs}];
        return execute(this.outputs, new FeedDict(feeds)) as Tensor[];
      });
      batchResults.forEach(
          (tensor, outputIndex) => collected[outputIndex].push(tensor));
    }

    // One concatenated tensor per model output; a lone output is unwrapped.
    return singletonOrArray(collected.map(parts => tfc.concat(parts, 0)));
  });
}
1077
1078 /**
1079 * Generates output predictions for the input samples.

Callers 2

predictMethod · 0.95
predictOnBatchMethod · 0.95

Calls 8

checkNumSamplesMethod · 0.95
makeBatchesFunction · 0.90
sliceArraysFunction · 0.90
executeFunction · 0.90
singletonOrArrayFunction · 0.90
tidyMethod · 0.80
concatMethod · 0.65
pushMethod · 0.45

Tested by

no test coverage detected