`run(flags_obj: flags.FlagValues, strategy_override: tf.distribute.Strategy = None)`

Runs the image classification model using native Keras APIs.

Args:
  flags_obj: An object containing parsed flag values.
  strategy_override: A `tf.distribute.Strategy` object to use for the model.

Returns:
  Dictionary of training/eval stats (in 'train_and_eval' mode).
```python
# Module-level imports this excerpt relies on (defined at the top of the file).
from typing import Any, Mapping, Optional

from absl import flags
import tensorflow as tf


def run(flags_obj: flags.FlagValues,
        strategy_override: tf.distribute.Strategy = None
        ) -> Optional[Mapping[str, Any]]:
  """Runs Image Classification model using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.
    strategy_override: A `tf.distribute.Strategy` object to use for model.

  Returns:
    Dictionary of training/eval stats in 'train_and_eval' mode; None in
    'export_only' mode.
  """
  params = _get_params_from_flags(flags_obj)
  if params.mode == 'train_and_eval':
    return train_and_eval(params, strategy_override)
  elif params.mode == 'export_only':
    # Export-only mode produces no training/eval stats to return.
    export(params)
    return None
  else:
    raise ValueError('{} is not a valid mode.'.format(params.mode))


def main(_):
  ...  # Body not shown in this excerpt.
```
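The core of `run` is a dispatch on `params.mode`. The sketch below isolates that dispatch so it can be exercised without TensorFlow. It is a minimal illustration, not the trainer's code: `Params`, the stub `train_and_eval`, the stub `export`, the `exported` list, and `run_with_params` are all hypothetical stand-ins, and the real `_get_params_from_flags` step is skipped by passing already-parsed params directly.

```python
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional


@dataclass
class Params:
  """Hypothetical stand-in for the config returned by _get_params_from_flags."""
  mode: str


def train_and_eval(params: Params,
                   strategy_override: Any = None) -> Mapping[str, Any]:
  """Hypothetical stub: the real function trains, evaluates, and returns stats."""
  return {'mode': params.mode, 'strategy': strategy_override}


# Records stub export calls so the dispatch can be observed (hypothetical).
exported: List[str] = []


def export(params: Params) -> None:
  """Hypothetical stub: returns nothing, matching how run() ignores its result."""
  exported.append(params.mode)


def run_with_params(params: Params,
                    strategy_override: Any = None
                    ) -> Optional[Mapping[str, Any]]:
  """Mirrors the mode dispatch in run(), starting from already-parsed params."""
  if params.mode == 'train_and_eval':
    return train_and_eval(params, strategy_override)
  elif params.mode == 'export_only':
    export(params)
    return None
  else:
    # Any other mode string is rejected, as in run().
    raise ValueError('{} is not a valid mode.'.format(params.mode))


if __name__ == '__main__':
  print(run_with_params(Params(mode='train_and_eval')))
```

Only 'train_and_eval' yields a stats dictionary; 'export_only' returns `None`, and an unrecognized mode fails fast with a `ValueError` rather than silently doing nothing.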