Runs the model export functionality.
(params: base_configs.ExperimentConfig)
| 406 | |
| 407 | |
def export(params: base_configs.ExperimentConfig):
  """Runs the model export functionality.

  Builds the model registered under `params.model.name` using
  `params.model.model_params`, restores its weights from a checkpoint, and
  saves the model to `params.export.destination`.

  Args:
    params: The experiment config. This function reads `params.model.name`,
      `params.model.model_params`, `params.export.checkpoint`,
      `params.export.destination`, and (only when no export checkpoint is
      given) `params.model_dir`.

  Raises:
    ValueError: If `params.export.checkpoint` is None and no checkpoint can be
      found in `params.model_dir`.
  """
  logging.info('Exporting model.')
  model_params = params.model.model_params.as_dict()
  model = get_models()[params.model.name](**model_params)
  checkpoint = params.export.checkpoint
  if checkpoint is None:
    logging.info('No export checkpoint was provided. Using the latest '
                 'checkpoint from model_dir.')
    checkpoint = tf.train.latest_checkpoint(params.model_dir)
    # `tf.train.latest_checkpoint` returns None when the directory contains
    # no checkpoint; fail here with an actionable message rather than letting
    # `load_weights(None)` raise an obscure error.
    if checkpoint is None:
      raise ValueError(
          'No export checkpoint was provided and no checkpoint was found in '
          f'model_dir: {params.model_dir!r}')

  model.load_weights(checkpoint)
  model.save(params.export.destination)
| 421 | |
| 422 | |
| 423 | def run(flags_obj: flags.FlagValues, |
no test coverage detected