* Standardize the output of a dataset iterator for use by * LayersModel.fitDataset(). * * @param model A `tf.LayersModel` object. * @param iteratorOut The output of a dataset iterator. It is required to be * an object of the form `{xs: TensorOrArrayOrMap, ys: * TensorOrArrayOrMap}`, where `
(
// Type `model` as `any` here to avoid circular dependency w/
// training.ts.
// tslint:disable-next-line:no-any
model: any, iteratorOut: {})
| 200 | * and `outputNames` of the model. |
| 201 | */ |
| 202 | function standardizeDataIteratorOutput( |
| 203 | // Type `model` as `any` here to avoid circular dependency w/ |
| 204 | // training.ts. |
| 205 | // tslint:disable-next-line:no-any |
| 206 | model: any, iteratorOut: {}): {xs: tfc.Tensor[], ys: tfc.Tensor[]} { |
| 207 | let xs: TensorOrArrayOrMap; |
| 208 | let ys: TensorOrArrayOrMap; |
| 209 | |
| 210 | const iteratorOutObj = iteratorOut as FitDatasetElement; |
| 211 | xs = iteratorOutObj['xs']; |
| 212 | ys = iteratorOutObj['ys']; |
| 213 | tfc.util.assert( |
| 214 | xs != null && ys != null, |
| 215 | () => 'A Dataset iterator for fitDataset() is expected to generate ' + |
| 216 | 'objects of the form `{xs: xVal, ys: yVal}`, where the two ' + |
| 217 | 'values may be `tf.Tensor`, an array of Tensors, or a map of ' + |
| 218 | 'string to Tensor. The provided Dataset instead generates ' + |
| 219 | `${iteratorOut}`); |
| 220 | |
| 221 | const flattenedXs: tfc.Tensor[] = |
| 222 | flattenTensorOrArrayOrMap('input', model.inputNames, xs); |
| 223 | const flattenedYs: tfc.Tensor[] = |
| 224 | flattenTensorOrArrayOrMap('output', model.outputNames, ys); |
| 225 | |
| 226 | const batchSize: number = flattenedXs[0].shape[0]; |
| 227 | |
| 228 | tfc.util.assert( |
| 229 | flattenedXs.length === model.inputs.length, |
| 230 | () => `LayersModel has ${model.inputs.length} inputs, but the dataset ` + |
| 231 | `provides ${flattenedXs.length} inputs. (Expected input keys: ` + |
| 232 | `${JSON.stringify(model.inputNames)})`); |
| 233 | |
| 234 | tfc.util.assert( |
| 235 | flattenedYs.length === model.outputs.length, |
| 236 | () => |
| 237 | `LayersModel has ${model.outputs.length} outputs, but the dataset ` + |
| 238 | `provides ${flattenedYs.length} outputs. (Expected output keys: ` + |
| 239 | `${JSON.stringify(model.outputNames)})`); |
| 240 | |
| 241 | for (let xIndex = 0; xIndex < flattenedXs.length; xIndex++) { |
| 242 | tfc.util.assert( |
| 243 | flattenedXs[xIndex].shape[0] === batchSize, |
| 244 | () => `Batch size mismatch: input ` + |
| 245 | `${model.inputNames[xIndex]} has ${ |
| 246 | flattenedXs[xIndex].shape[0]}; ` + |
| 247 | `expected ${batchSize} based on input ${model.inputNames[0]}.`); |
| 248 | } |
| 249 | |
| 250 | for (let yIndex = 0; yIndex < flattenedYs.length; yIndex++) { |
| 251 | tfc.util.assert( |
| 252 | flattenedYs[yIndex].shape[0] === batchSize, |
| 253 | () => `Batch size mismatch: output ` + |
| 254 | `${model.outputNames[yIndex]} has ${ |
| 255 | flattenedYs[yIndex].shape[0]}; ` + |
| 256 | `expected ${batchSize} based on input ${model.inputNames[0]}.`); |
| 257 | } |
| 258 | |
| 259 | return {xs: flattenedXs, ys: flattenedYs}; |
no test coverage detected
searching dependent graphs…