MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / ModelExporter

Class ModelExporter

tensorpack/tfutils/export.py:24–149  ·  view source on GitHub ↗

Export models for inference.

Source from the content-addressed store, hash-verified

22
23
24class ModelExporter(object):
25 """Export models for inference."""
26
def __init__(self, config):
    """Bind the exporter to the prediction config that describes the model.

    Args:
        config (PredictConfig): description of the model to export. Its
            tower function is used to build the inference graph, its
            input / output names pick which tensors the exported model
            exposes, and its ``session_init`` supplies the variable values.
    """
    # Cooperative initialisation through the MRO; the explicit two-argument
    # form is kept so the code stays Python-2 compatible.
    super(ModelExporter, self).__init__()
    # Kept untouched; export methods such as export_compact read it lazily
    # when they build and freeze the graph.
    self.config = config
37
38 def export_compact(self, filename, optimize=True, toco_compatible=False):
39 """Create a self-contained inference-only graph and write final graph (in pb format) to disk.
40
41 Args:
42 filename (str): path to the output graph
43 optimize (bool): whether to use TensorFlow's `optimize_for_inference`
44 to prune and optimize the graph. This does not work on all types of graphs.
45 toco_compatible (bool): See TensorFlow's
46 `optimize_for_inference
47 <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_
48 for details. Only available after TF 1.8.
49 """
50 if toco_compatible:
51 assert optimize, "toco_compatible is only effective when optimize=True!"
52 self.graph = self.config._maybe_create_graph()
53 with self.graph.as_default():
54 input = PlaceholderInput()
55 input.setup(self.config.input_signature)
56 with PredictTowerContext(''):
57 self.config.tower_func(*input.get_input_tensors())
58
59 input_tensors = get_tensors_by_names(self.config.input_names)
60 output_tensors = get_tensors_by_names(self.config.output_names)
61
62 self.config.session_init._setup_graph()
63 # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
64 sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
65 self.config.session_init._run_init(sess)
66
67 dtypes = [n.dtype for n in input_tensors]
68
69 # freeze variables to constants
70 frozen_graph_def = graph_util.convert_variables_to_constants(
71 sess,
72 self.graph.as_graph_def(),
73 [n.name[:-2] for n in output_tensors],
74 variable_names_whitelist=None,
75 variable_names_blacklist=None)
76
77 # prune unused nodes from graph
78 if optimize:
79 toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
80 frozen_graph_def = optimize_for_inference_lib.optimize_for_inference(
81 frozen_graph_def,

Callers 3

predict.py (File) · 0.90
export_serving (Function) · 0.90
export_compact (Function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected