Args: prms (dict): dict of {variable name: value}. Every name in prms must exist in the graph and in vars_to_update.
(self, prms)
| 121 | return value |
| 122 | |
def update(self, prms):
    """
    Assign new values to variables in the graph run by ``self.sess``.

    All assignments are collected first and then executed together in a
    single ``self.sess.run`` call.

    Args:
        prms(dict): dict of {variable name: value}
            Any name in prms must be in the graph and in vars_to_update.

    Raises:
        AssertionError: (when assertions are enabled) if a name in ``prms``
            is not a key of ``self.name_map``.
    """
    with self.sess.as_default():
        fetches = []
        feeds = {}
        for name, value in six.iteritems(prms):
            # Keep AssertionError as the failure type for compatibility, but
            # name the offending variable so the failure is actionable.
            assert name in self.name_map, \
                "Variable '{}' is not in the graph or not in vars_to_update!".format(name)
            var = self.name_map[name]
            # relaxed_value_for_var may return None (presumably when a
            # mismatch is tolerated via ignore_mismatch -- confirm in its
            # definition); such entries are skipped below and not loaded.
            value = SessionUpdate.relaxed_value_for_var(
                value, var, ignore_mismatch=self.ignore_mismatch)
            # This is the implementation of `var.load`: feed the value into
            # the initializer's value input (inputs[1]) and run the
            # initializer, so every variable is loaded in one sess.run.
            if value is not None:
                fetches.append(var.initializer)
                feeds[var.initializer.inputs[1]] = value
        self.sess.run(fetches, feed_dict=feeds)
| 142 | |
| 143 | |
| 144 | def dump_session_params(path): |
no test coverage detected