MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / _match_vars

Method _match_vars

tensorpack/tfutils/sessinit.py:125–150  ·  view source on GitHub ↗
(self, func)

Source from the content-addressed store, hash-verified

123 return reader, set(ckpt_vars)
124
def _match_vars(self, func):
    """
    Match the global variables of the current graph against the tensors
    stored in the checkpoint at ``self.path``, and call ``func`` on each match.

    Args:
        func: callback invoked as ``func(reader, name, v)`` for every graph
            variable ``v`` whose save-name ``name`` (computed from ``v.name``
            with ``varname_prefix=self.prefix``) is present in the checkpoint
            and is not listed in ``self.ignore``. ``reader`` is the first
            value returned by ``SaverRestore._read_checkpoint_vars``.

    Logging:
        * an info line for each graph variable skipped because of ``self.ignore``;
        * graph variables with no counterpart in the checkpoint;
        * checkpoint variables that were neither restored nor explicitly ignored.
        Names for which ``is_training_name`` returns True are left out of both
        mismatch reports (presumably training-only state such as optimizer
        slots -- see ``is_training_name``).
    """
    reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
    graph_vars = tf.global_variables()
    # Checkpoint names that were either restored through `func` or skipped on
    # purpose because of `self.ignore`. Neither kind should later be reported
    # as "present in the checkpoint but not found in the graph".
    chkpt_vars_handled = set()

    mismatch = MismatchLogger('graph', 'checkpoint')
    for v in graph_vars:
        name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
        if not reader.has_tensor(name):
            # use tensor name (instead of op name) for logging, to be consistent with the reverse case
            if not is_training_name(v.name):
                mismatch.add(v.name)
        elif name in self.ignore:
            logger.info("Variable {} in the graph will not be loaded from the checkpoint!".format(name))
            chkpt_vars_handled.add(name)
        else:
            func(reader, name, v)
            chkpt_vars_handled.add(name)
    mismatch.log()

    mismatch = MismatchLogger('checkpoint', 'graph')
    # chkpt_vars is a set; the difference is empty when everything was handled.
    for name in sorted(chkpt_vars - chkpt_vars_handled):
        if not is_training_name(name):
            mismatch.add(name)
    mismatch.log()
151
152 def _get_restore_dict(self):
153 var_dict = {}

Callers 2

_get_restore_dictMethod · 0.95
_run_initMethod · 0.80

Calls 8

addMethod · 0.95
logMethod · 0.95
MismatchLoggerClass · 0.85
is_training_nameFunction · 0.85
_read_checkpoint_varsMethod · 0.80
has_tensorMethod · 0.80
formatMethod · 0.80

Tested by

no test coverage detected