Export models for inference.
| 22 | |
| 23 | |
| 24 | class ModelExporter(object): |
| 25 | """Export models for inference.""" |
| 26 | |
| 27 | def __init__(self, config): |
| 28 | """Initialise the export process. |
| 29 | |
| 30 | Args: |
| 31 | config (PredictConfig): the config to use. |
| 32 | The graph will be built with the tower function defined by this `PredictConfig`. |
| 33 | Then the input / output names will be used to export models for inference. |
| 34 | """ |
| 35 | super(ModelExporter, self).__init__() |
| 36 | self.config = config |
| 37 | |
| 38 | def export_compact(self, filename, optimize=True, toco_compatible=False): |
| 39 | """Create a self-contained inference-only graph and write final graph (in pb format) to disk. |
| 40 | |
| 41 | Args: |
| 42 | filename (str): path to the output graph |
| 43 | optimize (bool): whether to use TensorFlow's `optimize_for_inference` |
| 44 | to prune and optimize the graph. This does not work on all types of graphs. |
| 45 | toco_compatible (bool): See TensorFlow's |
| 46 | `optimize_for_inference |
| 47 | <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_ |
| 48 | for details. Only available in TF 1.8 or later. |
| 49 | """ |
| 50 | if toco_compatible: |
| 51 | assert optimize, "toco_compatible is only effective when optimize=True!" |
| 52 | self.graph = self.config._maybe_create_graph() |
| 53 | with self.graph.as_default(): |
| 54 | input = PlaceholderInput() |
| 55 | input.setup(self.config.input_signature) |
| 56 | with PredictTowerContext(''): |
| 57 | self.config.tower_func(*input.get_input_tensors()) |
| 58 | |
| 59 | input_tensors = get_tensors_by_names(self.config.input_names) |
| 60 | output_tensors = get_tensors_by_names(self.config.output_names) |
| 61 | |
| 62 | self.config.session_init._setup_graph() |
| 63 | # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph |
| 64 | sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True)) |
| 65 | self.config.session_init._run_init(sess) |
| 66 | |
| 67 | dtypes = [n.dtype for n in input_tensors] |
| 68 | |
| 69 | # freeze variables to constants |
| 70 | frozen_graph_def = graph_util.convert_variables_to_constants( |
| 71 | sess, |
| 72 | self.graph.as_graph_def(), |
| 73 | [n.name[:-2] for n in output_tensors], |
| 74 | variable_names_whitelist=None, |
| 75 | variable_names_blacklist=None) |
| 76 | |
| 77 | # prune unused nodes from graph |
| 78 | if optimize: |
| 79 | toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, ) |
| 80 | frozen_graph_def = optimize_for_inference_lib.optimize_for_inference( |
| 81 | frozen_graph_def, |
no outgoing calls
no test coverage detected