* Helper method to loop over some data in batches. * * Porting Note: Not using the functional approach in the Python equivalent * due to the imperative backend. * Porting Note: Does not support step mode currently. * * @param ins: input data * @param batchSize: integer batch s
(ins: Tensor|Tensor[], batchSize = 32, verbose = false)
| 1031 | * `tf.Tensor` (if multiple outputs). |
| 1032 | */ |
private predictLoop(ins: Tensor|Tensor[], batchSize = 32, verbose = false):
    Tensor|Tensor[] {
  // The outer tidy disposes every intermediate per-batch tensor once the
  // concatenated result has been produced; only the returned value survives.
  return tfc.tidy(() => {
    // Sample count of the input(s); used below to carve the data into
    // batch ranges.
    const sampleCount = this.checkNumSamples(ins);
    if (verbose) {
      throw new NotImplementedError(
          'Verbose predictLoop() is not implemented yet.');
    }

    // Porting Note: Tensors lack numpy-style sliced assignment (x[1:3] = y),
    // so each batch's outputs are collected and concatenated at the end.
    const batchRanges = makeBatches(sampleCount, batchSize);

    // One list of per-batch chunks for each model output.
    const perOutputChunks: Tensor[][] = this.outputs.map(() => []);

    for (const [start, end] of batchRanges) {
      // The inner tidy frees the sliced inputs for this batch while keeping
      // the returned output tensors alive for the later concatenation.
      const batchOutputs = tfc.tidy(() => {
        // TODO(cais): Handle the case where the last element of the inputs is
        // a training/test flag.
        const slicedInputs = sliceArrays(ins, start, end);

        // Pair each model input with its slice of this batch; a lone tensor
        // feeds the first (only) input.
        const feeds = Array.isArray(slicedInputs) ?
            slicedInputs.map((value, i) => ({key: this.inputs[i], value})) :
            [{key: this.inputs[0], value: slicedInputs}];

        return execute(this.outputs, new FeedDict(feeds)) as Tensor[];
      });

      batchOutputs.forEach((tensor, outputIndex) => {
        perOutputChunks[outputIndex].push(tensor);
      });
    }

    // Stitch the per-batch chunks of every output back together along the
    // sample (first) axis.
    const merged = perOutputChunks.map(chunks => tfc.concat(chunks, 0));
    return singletonOrArray(merged);
  });
}
| 1077 | |
| 1078 | /** |
| 1079 | * Generates output predictions for the input samples. |
no test coverage detected