Exports a trained model as a `SavedModel` for inference. Args: model_export_path: a string specifying the path to the SavedModel directory. input_meta_data: dictionary containing meta data about input and model. bert_config: Bert configuration file to define core bert layers. model_dir: The directory where the model weights and training/evaluation summaries are stored.
(model_export_path, input_meta_data, bert_config,
model_dir)
| 327 | |
| 328 | |
def export_classifier(model_export_path, input_meta_data, bert_config,
                      model_dir):
  """Exports a trained model as a `SavedModel` for inference.

  Builds a BERT classifier model and hands it to
  `model_saving_utils.export_bert_model` together with `model_dir` as the
  checkpoint directory.

  Note: as a side effect this sets the global Keras mixed-precision policy to
  'float32', even if training used mixed precision.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.
      Only the 'num_labels' key is read here (defaults to 1 when absent).
    bert_config: Bert configuration file to define core bert layers.
    model_dir: The directory where the model weights and training/evaluation
      summaries are stored.

  Raises:
    ValueError: if `model_export_path` or `model_dir` is empty or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  if not model_dir:
    # Previously this reused the export-path message, which misreported which
    # argument was missing.
    raise ValueError('Model directory is not specified: %s' % model_dir)

  # Export uses float32 for now, even if training uses mixed precision.
  tf_keras.mixed_precision.set_global_policy('float32')
  # `classifier_model` returns a tuple; only its first element is exported.
  # The hub module URL comes from the global FLAGS, and the hub layer is
  # frozen because this model is only used for export.
  classifier_model = bert_models.classifier_model(
      bert_config,
      input_meta_data.get('num_labels', 1),
      hub_module_url=FLAGS.hub_module_url,
      hub_module_trainable=False)[0]

  # NOTE(review): presumably export_bert_model restores weights from the
  # checkpoint(s) in `model_dir` before saving; confirm in model_saving_utils.
  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=model_dir)
| 358 | |
| 359 | |
| 360 | def run_bert(strategy, |