* Configures and prepares the model for training and evaluation. Compiling outfits the model with an optimizer, loss, and/or metrics. Calling `fit` or `evaluate` on an un-compiled model will throw an error. @param args a `ModelCompileArgs` specifying the loss, optimizer, and
(args: ModelCompileArgs)
| 581 | * @doc {heading: 'Models', subheading: 'Classes'} |
| 582 | */ |
| 583 | compile(args: ModelCompileArgs): void { |
| 584 | if (args.loss == null) { |
| 585 | args.loss = []; |
| 586 | } |
| 587 | this.loss = args.loss; |
| 588 | |
| 589 | if (typeof args.optimizer === 'string') { |
| 590 | this.optimizer_ = optimizers.getOptimizer(args.optimizer); |
| 591 | this.isOptimizerOwned = true; |
| 592 | } else { |
| 593 | if (!(args.optimizer instanceof Optimizer)) { |
| 594 | throw new ValueError( |
| 595 | `User-defined optimizer must be an instance of tf.Optimizer.`); |
| 596 | } |
| 597 | this.optimizer_ = args.optimizer; |
| 598 | this.isOptimizerOwned = false; |
| 599 | } |
| 600 | |
| 601 | // TODO(cais): Add lossWeights. |
| 602 | // TODO(cais): Add sampleWeightMode. |
| 603 | |
| 604 | // Prepare loss functions. |
| 605 | let lossFunctions: LossOrMetricFn[] = []; |
| 606 | if (!Array.isArray(args.loss) && typeof args.loss !== 'string' && |
| 607 | typeof args.loss !== 'function') { |
| 608 | args.loss = args.loss as {[outputName: string]: string}; |
| 609 | for (const name in args.loss) { |
| 610 | if (this.outputNames.indexOf(name) === -1) { |
| 611 | throw new ValueError( |
| 612 | `Unknown entry in loss dictionary: "${name}". ` + |
| 613 | `Only expected the following keys: ${this.outputNames}`); |
| 614 | } |
| 615 | } |
| 616 | for (const name of this.outputNames) { |
| 617 | if (args.loss[name] == null) { |
| 618 | console.warn( |
| 619 | `Output "${name}" is missing from loss dictionary. We assume ` + |
| 620 | `this was done on purpose, and we will not be expecting data ` + |
| 621 | `to be passed to ${name} during training`); |
| 622 | } |
| 623 | lossFunctions.push(losses.get(args.loss[name])); |
| 624 | } |
| 625 | } else if (Array.isArray(args.loss)) { |
| 626 | if (args.loss.length !== this.outputs.length) { |
| 627 | throw new ValueError( |
| 628 | `When passing an Array as loss, it should have one entry per ` + |
| 629 | `model output. The model has ${this.outputs.length} output(s), ` + |
| 630 | `but you passed loss=${args.loss}.`); |
| 631 | } |
| 632 | const theLosses = args.loss as Array<string|LossOrMetricFn>; |
| 633 | lossFunctions = theLosses.map(l => losses.get(l)); |
| 634 | } else { |
| 635 | const lossFunction = losses.get(args.loss); |
| 636 | this.outputs.forEach(_ => { |
| 637 | lossFunctions.push(lossFunction); |
| 638 | }); |
| 639 | } |
| 640 |
no test coverage detected