MCPcopy
hub / github.com/hojonathanho/diffusion / _dump_samples

Method _dump_samples

diffusion_tf/tpu_utils/tpu_utils.py:406–444  ·  view source on GitHub ↗
(self, sess, curr_step, samples_dir, ema: bool, num_samples=50000)

Source from the content-addressed store, hash-verified

404 log.flush()
405
def _dump_samples(self, sess, curr_step, samples_dir, ema: bool, num_samples=50000):
    """Generate samples and write them to ``samples_dir``, one pickle file per batch.

    Each batch is fetched with ``sess.run``, cast to float32 and pickled to
    ``<samples_dir>/samples_ema{0|1}_step{curr_step:09d}.pkl.batch{i:05d}``.
    Batches are written as they are produced rather than concatenated first
    (presumably to avoid holding every sample in memory at once).

    Args:
        sess: session used to run the sampling fetches.
        curr_step: training step, embedded in the output filename.
        samples_dir: output directory; created if it does not exist.
        ema: if true, fetch ``self.ema_samples_outputs`` instead of
            ``self.samples_outputs``.
        num_samples: total number of samples to write across all batches.

    Raises:
        AssertionError: if any output file already exists, or if a fetched
            batch does not have the same keys as ``self.samples_outputs``.
    """
    print('will dump samples to', samples_dir)
    if not tf.gfile.IsDirectory(samples_dir):
        tf.gfile.MakeDirs(samples_dir)
    filename = os.path.join(
        samples_dir, 'samples_ema{}_step{:09d}.pkl'.format(int(ema), curr_step))
    assert not tf.io.gfile.exists(filename), 'samples file already exists: {}'.format(filename)

    num_gen_batches = int(np.ceil(num_samples / self.total_bs))
    print('generating {} samples ({} batches)...'.format(num_samples, num_gen_batches))

    # The files actually written are the per-batch ones, so guard those as well,
    # and do it before any (expensive) sampling so a rerun cannot silently
    # overwrite an earlier dump.
    batch_filenames = ['{}.batch{:05d}'.format(filename, i) for i in range(num_gen_batches)]
    for batch_filename in batch_filenames:
        assert not tf.io.gfile.exists(batch_filename), (
            'samples file already exists: {}'.format(batch_filename))

    fetches = self.ema_samples_outputs if ema else self.samples_outputs
    expected_keys = set(self.samples_outputs.keys())

    for i in trange(num_gen_batches, desc='sampling'):
        b = sess.run(fetches)
        assert set(b.keys()) == expected_keys

        # Trim the final batch so the dump holds num_samples samples in total
        # instead of a whole number of batches. Assumes every output has the
        # batch on axis 0 with self.total_bs entries per run -- TODO confirm.
        keep = min(self.total_bs, num_samples - i * self.total_bs)
        b = {
            k: b[k].astype(np.float32)[:keep] for k in self.samples_outputs.keys()
        }

        filename_i = batch_filenames[i]
        print('writing samples for batch', i, 'to:', filename_i)
        with tf.io.gfile.GFile(filename_i, 'wb') as f:
            f.write(pickle.dumps(b, protocol=pickle.HIGHEST_PROTOCOL))
    print('done writing samples')
445
446 def run(self, logdir, once: bool, skip_non_ema_pass=True, dump_samples_only=False, load_ckpt=None, samples_dir=None, seed=0):
447 """Runs the eval/sampling worker loop.

Callers 1

runMethod · 0.95

Calls 1

runMethod · 0.45

Tested by

no test coverage detected