(trainingConfig: TrainingConfig)
| 1989 | } |
| 1990 | |
/**
 * Compiles this model from a training configuration stored in Keras
 * (Python) format, i.e. with snake_case field names.
 *
 * The optimizer config is run through `convertPythonicToTs` and then
 * `deserialize` to obtain an `Optimizer`. Loss and metric identifiers are
 * passed through `toCamelCase` (presumably snake_case -> camelCase, e.g.
 * `mean_squared_error` -> `meanSquaredError`; confirm against that helper).
 * A loss may be a single identifier, an array of identifiers, or an object
 * mapping output names to identifiers. Metrics may be an array or such an
 * object. The converted values are passed to `this.compile()`.
 *
 * If `trainingConfig.loss` (or `metrics`) is null/undefined, the
 * corresponding value is left `undefined` when calling `compile`.
 *
 * @param trainingConfig Training config with Python-style field names.
 * @throws Error if `weighted_metrics`, `loss_weights` or
 *   `sample_weight_mode` is present, because loading those is not
 *   supported yet.
 */
loadTrainingConfig(trainingConfig: TrainingConfig) {
  if (trainingConfig.weighted_metrics != null) {
    throw new Error('Loading weighted_metrics is not supported yet.');
  }
  if (trainingConfig.loss_weights != null) {
    throw new Error('Loading loss_weights is not supported yet.');
  }
  if (trainingConfig.sample_weight_mode != null) {
    throw new Error('Loading sample_weight_mode is not supported yet.');
  }

  // Optimizer config keys are Python-style; convert them before
  // deserializing into an Optimizer instance.
  const tsConfig = convertPythonicToTs(trainingConfig.optimizer_config) as
      serialization.ConfigDict;
  const optimizer = deserialize(tsConfig) as Optimizer;

  let loss;
  if (typeof trainingConfig.loss === 'string') {
    loss = toCamelCase(trainingConfig.loss);
  } else if (Array.isArray(trainingConfig.loss)) {
    loss = trainingConfig.loss.map(lossEntry => toCamelCase(lossEntry));
  } else if (trainingConfig.loss != null) {
    // Per-output losses keyed by output name. Iterate own keys only:
    // `for...in` would also visit inherited enumerable properties.
    loss = {} as {[outputName: string]: LossIdentifier};
    for (const key of Object.keys(trainingConfig.loss)) {
      loss[key] = toCamelCase(trainingConfig.loss[key]) as LossIdentifier;
    }
  }

  let metrics;
  if (Array.isArray(trainingConfig.metrics)) {
    metrics = trainingConfig.metrics.map(metric => toCamelCase(metric));
  } else if (trainingConfig.metrics != null) {
    // Per-output metrics keyed by output name; own keys only, as for loss.
    metrics = {} as {[outputName: string]: MetricsIdentifier};
    for (const key of Object.keys(trainingConfig.metrics)) {
      metrics[key] = toCamelCase(trainingConfig.metrics[key]);
    }
  }

  this.compile({loss, metrics, optimizer});
}
| 2030 | |
| 2031 | /** |
| 2032 | * Save the configuration and/or weights of the LayersModel. |
no test coverage detected