| 277 | return img_summary |
| 278 | |
def build_model(self, sess):
    """Initialise all variables and optionally restore them from a checkpoint.

    Calls ``self.init_opt()``, runs the variable initialiser in ``sess`` and,
    when ``self.model_path`` is non-empty, restores every variable from that
    checkpoint.

    Args:
        sess: TensorFlow session used to run the initialiser and the restore.

    Returns:
        int: the step counter parsed from the checkpoint name, i.e. the text
        between the last ``_`` and the first following ``.`` searched from
        the right (or the end of the path when no extension follows); 0 when
        no checkpoint path is configured.

    Raises:
        ValueError: if the text extracted from the checkpoint name is not an
            integer.
    """
    self.init_opt()
    # NOTE(review): initialize_all_variables / all_variables are the legacy
    # spellings of global_variables_initializer / global_variables; they are
    # kept here because the rest of the file targets the old API.
    sess.run(tf.initialize_all_variables())

    # Truthiness covers both '' and None, whereas len(None) would raise.
    if not self.model_path:
        print("Created model with fresh parameters.")
        return 0

    print("Reading model parameters from %s" % self.model_path)
    saver = tf.train.Saver(tf.all_variables())
    saver.restore(sess, self.model_path)

    # The counter sits after the last '_' (e.g. ".../model_1000.ckpt").
    istart = self.model_path.rfind('_') + 1
    # Only look for the extension dot *after* the underscore: an unconstrained
    # rfind('.') returns -1 for extension-less names (silently dropping the
    # last digit) or an earlier index when a directory name contains a dot.
    iend = self.model_path.rfind('.', istart)
    if iend == -1:
        iend = len(self.model_path)
    return int(self.model_path[istart:iend])
| 301 | |
| 302 | def train(self): |
| 303 | config = tf.ConfigProto(allow_soft_placement=True) |