MCPcopy
hub / github.com/hojonathanho/diffusion / run

Method run

diffusion_tf/tpu_utils/tpu_utils.py:446–515  ·  view source on GitHub ↗

Runs the eval/sampling worker loop. Args: `logdir`: directory to read checkpoints from. `once`: if True, writes results to a temporary directory (a timestamped `eval_once_*` subdirectory of `logdir`, rather than the shared `logdir/eval`), and exits after evaluating one checkpoint.

(self, logdir, once: bool, skip_non_ema_pass=True, dump_samples_only=False, load_ckpt=None, samples_dir=None, seed=0)

Source from the content-addressed store, hash-verified

444 print('done writing samples')
445
446 def run(self, logdir, once: bool, skip_non_ema_pass=True, dump_samples_only=False, load_ckpt=None, samples_dir=None, seed=0):
447 """Runs the eval/sampling worker loop.
448 Args:
449 logdir: directory to read checkpoints from
450 once: if True, writes results to a temporary directory (not to logdir),
451 and exits after evaluating one checkpoint.
452 """
453 tf.logging.set_verbosity(tf.logging.INFO)
454
455 # Are we evaluating a single checkpoint or looping on the latest?
456 if load_ckpt is not None:
457 # load_ckpt should be of the form: model.ckpt-1000000
458 assert tf.io.gfile.exists(os.path.join(logdir, load_ckpt) + '.index')
459 ckpt_iterator = [os.path.join(logdir, load_ckpt)] # load this one checkpoint only
460 else:
461 ckpt_iterator = tf.train.checkpoints_iterator(logdir) # wait for checkpoints to come in
462 assert tf.io.gfile.isdir(logdir), 'expected {} to be a directory'.format(logdir)
463
464 # Set up eval SummaryWriter
465 if once:
466 eval_logdir = os.path.join(logdir, 'eval_once_{}'.format(time.time()))
467 else:
468 eval_logdir = os.path.join(logdir, 'eval')
469 print('Writing eval data to: {}'.format(eval_logdir))
470 eval_log = utils.SummaryWriter(eval_logdir, write_graph=False)
471
472 # Make the session
473 config = tf.ConfigProto()
474 config.allow_soft_placement = True
475 cluster_spec = self.resolver.cluster_spec()
476 if cluster_spec:
477 config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
478 print('making session...')
479 with tf.Session(target=self.resolver.master(), config=config) as sess:
480
481 print('initializing global variables')
482 sess.run(tf.global_variables_initializer())
483
484 # Checkpoint loading
485 print('making saver')
486 saver = tf.train.Saver()
487
488 for ckpt in ckpt_iterator:
489 # Restore params
490 saver.restore(sess, ckpt)
491 global_step_val = sess.run(self.global_step)
492 print('restored global step: {}'.format(global_step_val))
493
494 print('seeding')
495 utils.seed_all(seed)
496
497 print('ema pass')
498 if dump_samples_only:
499 if not samples_dir:
500 samples_dir = os.path.join(eval_logdir, '{}_samples{}'.format(type(self.dataset).__name__, global_step_val))
501 self._dump_samples(
502 sess, curr_step=global_step_val, samples_dir=samples_dir, ema=True)
503 else:

Callers 8

evaluation (Function) · 0.95
evaluation (Function) · 0.95
evaluation (Function) · 0.95
test_all (Function) · 0.45
get (Method) · 0.45
_run_sampling (Method) · 0.45
_run_metrics (Method) · 0.45
_dump_samples (Method) · 0.45

Calls 2

_dump_samples (Method) · 0.95

Tested by 1

test_all (Function) · 0.36