Runs the experiment.
(callbacks=None)
| 173 | |
| 174 | |
| 175 | def run(callbacks=None): |
| 176 | """Runs the experiment.""" |
| 177 | keras_utils.set_session_config(enable_xla=FLAGS.enable_xla) |
| 178 | |
| 179 | params = config_factory.config_generator(FLAGS.model) |
| 180 | |
| 181 | params = params_dict.override_params_dict( |
| 182 | params, FLAGS.config_file, is_strict=True) |
| 183 | |
| 184 | params = params_dict.override_params_dict( |
| 185 | params, FLAGS.params_override, is_strict=True) |
| 186 | params.override( |
| 187 | { |
| 188 | 'strategy_type': FLAGS.strategy_type, |
| 189 | 'model_dir': FLAGS.model_dir, |
| 190 | 'strategy_config': executor.strategy_flags_dict(), |
| 191 | }, |
| 192 | is_strict=False) |
| 193 | |
| 194 | # Make sure use_tpu and strategy_type are in sync. |
| 195 | params.use_tpu = (params.strategy_type == 'tpu') |
| 196 | |
| 197 | if not params.use_tpu: |
| 198 | params.override({ |
| 199 | 'architecture': { |
| 200 | 'use_bfloat16': False, |
| 201 | }, |
| 202 | 'norm_activation': { |
| 203 | 'use_sync_bn': False, |
| 204 | }, |
| 205 | }, is_strict=True) |
| 206 | |
| 207 | params.validate() |
| 208 | params.lock() |
| 209 | pp = pprint.PrettyPrinter() |
| 210 | params_str = pp.pformat(params.as_dict()) |
| 211 | logging.info('Model Parameters: %s', params_str) |
| 212 | |
| 213 | train_input_fn = None |
| 214 | eval_input_fn = None |
| 215 | training_file_pattern = FLAGS.training_file_pattern or params.train.train_file_pattern |
| 216 | eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern |
| 217 | if not training_file_pattern and not eval_file_pattern: |
| 218 | raise ValueError('Must provide at least one of training_file_pattern and ' |
| 219 | 'eval_file_pattern.') |
| 220 | |
| 221 | if training_file_pattern: |
| 222 | # Use global batch size for single host. |
| 223 | train_input_fn = input_reader.InputFn( |
| 224 | file_pattern=training_file_pattern, |
| 225 | params=params, |
| 226 | mode=input_reader.ModeKeys.TRAIN, |
| 227 | batch_size=params.train.batch_size) |
| 228 | |
| 229 | if eval_file_pattern: |
| 230 | eval_input_fn = input_reader.InputFn( |
| 231 | file_pattern=eval_file_pattern, |
| 232 | params=params, |