MCPcopy
hub / github.com/tensorflow/tfjs / compile

Method compile

tfjs-layers/src/engine/training.ts:583–785  ·  view source on GitHub ↗

Configures and prepares the model for training and evaluation. Compiling outfits the model with an optimizer, loss, and/or metrics. Calling `fit` or `evaluate` on an un-compiled model will throw an error.
@param args a `ModelCompileArgs` specifying the loss, optimizer, and …

(args: ModelCompileArgs)

Source from the content-addressed store, hash-verified

581 * @doc {heading: 'Models', subheading: 'Classes'}
582 */
583 compile(args: ModelCompileArgs): void {
584 if (args.loss == null) {
585 args.loss = [];
586 }
587 this.loss = args.loss;
588
589 if (typeof args.optimizer === 'string') {
590 this.optimizer_ = optimizers.getOptimizer(args.optimizer);
591 this.isOptimizerOwned = true;
592 } else {
593 if (!(args.optimizer instanceof Optimizer)) {
594 throw new ValueError(
595 `User-defined optimizer must be an instance of tf.Optimizer.`);
596 }
597 this.optimizer_ = args.optimizer;
598 this.isOptimizerOwned = false;
599 }
600
601 // TODO(cais): Add lossWeights.
602 // TODO(cais): Add sampleWeightMode.
603
604 // Prepare loss functions.
605 let lossFunctions: LossOrMetricFn[] = [];
606 if (!Array.isArray(args.loss) && typeof args.loss !== 'string' &&
607 typeof args.loss !== 'function') {
608 args.loss = args.loss as {[outputName: string]: string};
609 for (const name in args.loss) {
610 if (this.outputNames.indexOf(name) === -1) {
611 throw new ValueError(
612 `Unknown entry in loss dictionary: "${name}". ` +
613 `Only expected the following keys: ${this.outputNames}`);
614 }
615 }
616 for (const name of this.outputNames) {
617 if (args.loss[name] == null) {
618 console.warn(
619 `Output "${name}" is missing from loss dictionary. We assume ` +
620 `this was done on purpose, and we will not be expecting data ` +
621 `to be passed to ${name} during training`);
622 }
623 lossFunctions.push(losses.get(args.loss[name]));
624 }
625 } else if (Array.isArray(args.loss)) {
626 if (args.loss.length !== this.outputs.length) {
627 throw new ValueError(
628 `When passing an Array as loss, it should have one entry per ` +
629 `model output. The model has ${this.outputs.length} output(s), ` +
630 `but you passed loss=${args.loss}.`);
631 }
632 const theLosses = args.loss as Array<string|LossOrMetricFn>;
633 lossFunctions = theLosses.map(l => losses.get(l));
634 } else {
635 const lossFunction = losses.get(args.loss);
636 this.outputs.forEach(_ => {
637 lossFunctions.push(lossFunction);
638 });
639 }
640

Callers 6

loadTrainingConfigMethod · 0.95
training_test.tsFile · 0.45
createDummyModelFunction · 0.45
container_test.tsFile · 0.45

Calls 4

nameScopeFunction · 0.90
collectMetricsFunction · 0.85
pushMethod · 0.45
getMethod · 0.45

Tested by

no test coverage detected