Runs the eval/sampling worker loop. Args: logdir: directory to read checkpoints from. once: if True, writes results to a fresh timestamped subdirectory of logdir (`eval_once_<time>`) instead of the shared `logdir/eval` directory, and exits after evaluating one checkpoint.
(self, logdir, once: bool, skip_non_ema_pass=True, dump_samples_only=False, load_ckpt=None, samples_dir=None, seed=0)
| 444 | print('done writing samples') |
| 445 | |
| 446 | def run(self, logdir, once: bool, skip_non_ema_pass=True, dump_samples_only=False, load_ckpt=None, samples_dir=None, seed=0): |
| 447 | """Runs the eval/sampling worker loop. |
| 448 | Args: |
| 449 | logdir: directory to read checkpoints from |
| 450 | once: if True, writes results to a temporary directory (not to logdir), |
| 451 | and exits after evaluating one checkpoint. |
| 452 | """ |
| 453 | tf.logging.set_verbosity(tf.logging.INFO) |
| 454 | |
| 455 | # Are we evaluating a single checkpoint or looping on the latest? |
| 456 | if load_ckpt is not None: |
| 457 | # load_ckpt should be of the form: model.ckpt-1000000 |
| 458 | assert tf.io.gfile.exists(os.path.join(logdir, load_ckpt) + '.index') |
| 459 | ckpt_iterator = [os.path.join(logdir, load_ckpt)] # load this one checkpoint only |
| 460 | else: |
| 461 | ckpt_iterator = tf.train.checkpoints_iterator(logdir) # wait for checkpoints to come in |
| 462 | assert tf.io.gfile.isdir(logdir), 'expected {} to be a directory'.format(logdir) |
| 463 | |
| 464 | # Set up eval SummaryWriter |
| 465 | if once: |
| 466 | eval_logdir = os.path.join(logdir, 'eval_once_{}'.format(time.time())) |
| 467 | else: |
| 468 | eval_logdir = os.path.join(logdir, 'eval') |
| 469 | print('Writing eval data to: {}'.format(eval_logdir)) |
| 470 | eval_log = utils.SummaryWriter(eval_logdir, write_graph=False) |
| 471 | |
| 472 | # Make the session |
| 473 | config = tf.ConfigProto() |
| 474 | config.allow_soft_placement = True |
| 475 | cluster_spec = self.resolver.cluster_spec() |
| 476 | if cluster_spec: |
| 477 | config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) |
| 478 | print('making session...') |
| 479 | with tf.Session(target=self.resolver.master(), config=config) as sess: |
| 480 | |
| 481 | print('initializing global variables') |
| 482 | sess.run(tf.global_variables_initializer()) |
| 483 | |
| 484 | # Checkpoint loading |
| 485 | print('making saver') |
| 486 | saver = tf.train.Saver() |
| 487 | |
| 488 | for ckpt in ckpt_iterator: |
| 489 | # Restore params |
| 490 | saver.restore(sess, ckpt) |
| 491 | global_step_val = sess.run(self.global_step) |
| 492 | print('restored global step: {}'.format(global_step_val)) |
| 493 | |
| 494 | print('seeding') |
| 495 | utils.seed_all(seed) |
| 496 | |
| 497 | print('ema pass') |
| 498 | if dump_samples_only: |
| 499 | if not samples_dir: |
| 500 | samples_dir = os.path.join(eval_logdir, '{}_samples{}'.format(type(self.dataset).__name__, global_step_val)) |
| 501 | self._dump_samples( |
| 502 | sess, curr_step=global_step_val, samples_dir=samples_dir, ema=True) |
| 503 | else: |