MCPcopy
hub / github.com/tensorflow/tfjs / standardizeDataIteratorOutput

Function standardizeDataIteratorOutput

tfjs-layers/src/engine/training_dataset.ts:202–260  ·  view source on GitHub ↗

Standardize the output of a dataset iterator for use by `LayersModel.fitDataset()`. @param model A `tf.LayersModel` object. @param iteratorOut The output of a dataset iterator. It is required to be an object of the form `{xs: TensorOrArrayOrMap, ys: TensorOrArrayOrMap}`, where each of `xs` and `ys` may be a `tf.Tensor`, an array of `tf.Tensor`s, or a map from string names to `tf.Tensor`s.

(
    // Type `model` as `any` here to avoid circular dependency w/
    // training.ts.
    // tslint:disable-next-line:no-any
    model: any, iteratorOut: {})

Source from the content-addressed store, hash-verified

200 * and `outputNames` of the model.
201 */
function standardizeDataIteratorOutput(
    // Type `model` as `any` here to avoid circular dependency w/
    // training.ts.
    // tslint:disable-next-line:no-any
    model: any, iteratorOut: {}): {xs: tfc.Tensor[], ys: tfc.Tensor[]} {
  // Pull the `xs` / `ys` fields out of the iterator element; both must be
  // present for the element to be usable.
  const iteratorOutObj = iteratorOut as FitDatasetElement;
  const xs: TensorOrArrayOrMap = iteratorOutObj['xs'];
  const ys: TensorOrArrayOrMap = iteratorOutObj['ys'];
  tfc.util.assert(
      xs != null && ys != null,
      () => {
        // Interpolating an object directly would print "[object Object]";
        // serialize it so the message shows what the Dataset produced.
        let generated: string;
        try {
          generated = JSON.stringify(iteratorOut);
        } catch (e) {
          generated = String(iteratorOut);
        }
        return 'A Dataset iterator for fitDataset() is expected to generate ' +
            'objects of the form `{xs: xVal, ys: yVal}`, where the two ' +
            'values may be `tf.Tensor`, an array of Tensors, or a map of ' +
            'string to Tensor. The provided Dataset instead generates ' +
            generated;
      });

  // Flatten each side into an array ordered by the model's input/output
  // names.
  const flattenedXs: tfc.Tensor[] =
      flattenTensorOrArrayOrMap('input', model.inputNames, xs);
  const flattenedYs: tfc.Tensor[] =
      flattenTensorOrArrayOrMap('output', model.outputNames, ys);

  // Validate the tensor counts before indexing `flattenedXs[0]` below, so a
  // count mismatch is reported by a descriptive assertion, not a TypeError.
  tfc.util.assert(
      flattenedXs.length === model.inputs.length,
      () => `LayersModel has ${model.inputs.length} inputs, but the dataset ` +
          `provides ${flattenedXs.length} inputs. (Expected input keys: ` +
          `${JSON.stringify(model.inputNames)})`);

  tfc.util.assert(
      flattenedYs.length === model.outputs.length,
      () =>
          `LayersModel has ${model.outputs.length} outputs, but the dataset ` +
          `provides ${flattenedYs.length} outputs. (Expected output keys: ` +
          `${JSON.stringify(model.outputNames)})`);

  // The first input's leading dimension is the reference batch size that
  // every other input and output must match.
  const batchSize: number = flattenedXs[0].shape[0];

  for (let xIndex = 0; xIndex < flattenedXs.length; xIndex++) {
    tfc.util.assert(
        flattenedXs[xIndex].shape[0] === batchSize,
        () => `Batch size mismatch: input ` +
            `${model.inputNames[xIndex]} has ${
                  flattenedXs[xIndex].shape[0]}; ` +
            `expected ${batchSize} based on input ${model.inputNames[0]}.`);
  }

  for (let yIndex = 0; yIndex < flattenedYs.length; yIndex++) {
    tfc.util.assert(
        flattenedYs[yIndex].shape[0] === batchSize,
        () => `Batch size mismatch: output ` +
            `${model.outputNames[yIndex]} has ${
                  flattenedYs[yIndex].shape[0]}; ` +
            `expected ${batchSize} based on input ${model.inputNames[0]}.`);
  }

  return {xs: flattenedXs, ys: flattenedYs};
}

Callers 2

fitDatasetFunction · 0.85
evaluateDatasetFunction · 0.85

Calls 1

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…