(self, func)
| 123 | return reader, set(ckpt_vars) |
| 124 | |
def _match_vars(self, func):
    """Pair graph variables with checkpoint tensors and report mismatches.

    For every variable in ``tf.global_variables()``, its checkpoint name is
    derived with ``get_savename_from_varname`` (using ``self.prefix``). If the
    checkpoint read from ``self.path`` has a tensor of that name and the name
    is not in ``self.ignore``, ``func(reader, name, v)`` is called.

    Two ``MismatchLogger`` reports are emitted:
      * graph variables with no matching checkpoint tensor, and
      * checkpoint tensors that no graph variable accounted for.
    Names for which ``is_training_name`` is true are left out of both reports.

    Args:
        func: callable invoked as ``func(reader, name, v)`` for each matched
            pair. ``reader`` is the first value returned by
            ``SaverRestore._read_checkpoint_vars`` (it exposes ``has_tensor``);
            ``v`` is the graph variable. Presumably it restores or collects
            the value -- its exact role is defined by the caller.
    """
    reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
    graph_vars = tf.global_variables()
    # Checkpoint names that some graph variable has accounted for: either
    # handed to ``func`` or deliberately skipped via ``self.ignore``.
    chkpt_vars_used = set()

    mismatch = MismatchLogger('graph', 'checkpoint')
    for v in graph_vars:
        name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
        in_checkpoint = reader.has_tensor(name)
        if name in self.ignore and in_checkpoint:
            logger.info("Variable {} in the graph will not be loaded from the checkpoint!".format(name))
            # The name exists on both sides and is skipped on purpose, so it
            # must not later be reported as "in checkpoint but not in graph".
            chkpt_vars_used.add(name)
        elif in_checkpoint:
            func(reader, name, v)
            chkpt_vars_used.add(name)
        elif not is_training_name(v.name):
            # use tensor name (instead of op name) for logging, to be consistent with the reverse case
            mismatch.add(v.name)
    mismatch.log()

    mismatch = MismatchLogger('checkpoint', 'graph')
    if len(chkpt_vars_used) < len(chkpt_vars):
        unused = chkpt_vars - chkpt_vars_used
        for name in sorted(unused):
            if not is_training_name(name):
                mismatch.add(name)
    mismatch.log()
| 151 | |
| 152 | def _get_restore_dict(self): |
| 153 | var_dict = {} |
no test coverage detected