| 404 | log.flush() |
| 405 | |
def _dump_samples(self, sess, curr_step, samples_dir, ema: bool, num_samples=50000):
    """Generate samples and pickle them to `samples_dir`, one file per sampling batch.

    Each batch is written to its own file named
    ``samples_ema{int(ema)}_step{curr_step:09d}.pkl.batch{i:05d}``. Each file holds a pickled
    dict that maps every key of ``self.samples_outputs`` to a float32 numpy array. Writing
    per batch means only a single batch has to be held in memory at a time.

    Args:
      sess: session used to evaluate the sampling ops via ``sess.run``.
      curr_step: training step being sampled; embedded in the output filenames.
      samples_dir: output directory (any path ``tf.io.gfile`` supports); created if missing.
      ema: if True, evaluate ``self.ema_samples_outputs``, otherwise ``self.samples_outputs``.
      num_samples: total number of samples to write across all batch files. The final batch
        is truncated so that exactly this many are written. This assumes axis 0 of every
        output indexes samples and each ``sess.run`` yields ``self.total_bs`` of them.

    Raises:
      AssertionError: if a sample dump for this step/ema setting already exists, or if a
        batch's output keys differ from ``self.samples_outputs``.
    """
    print('will dump samples to', samples_dir)
    if not tf.io.gfile.isdir(samples_dir):
        tf.io.gfile.makedirs(samples_dir)
    filename = os.path.join(
        samples_dir, 'samples_ema{}_step{:09d}.pkl'.format(int(ema), curr_step))
    # Only the '<filename>.batchNNNNN' files are written, so checking `filename` alone never
    # detects a previous dump. Also reject existing batch files so a rerun cannot silently
    # overwrite them. The bare `filename` is still checked to catch legacy single-file dumps.
    existing = tf.io.gfile.glob(filename + '.batch*')
    assert not existing and not tf.io.gfile.exists(filename), \
        'samples file already exists: {}'.format(existing[0] if existing else filename)

    num_gen_batches = int(np.ceil(num_samples / self.total_bs))
    print('generating {} samples ({} batches)...'.format(num_samples, num_gen_batches))

    # Loop invariants: the ops to evaluate and the keys every batch must contain.
    outputs = self.ema_samples_outputs if ema else self.samples_outputs
    expected_keys = set(self.samples_outputs.keys())

    for i in trange(num_gen_batches, desc='sampling'):
        batch = sess.run(outputs)
        assert set(batch.keys()) == expected_keys
        # The last batch may overshoot num_samples (ceil above). Trim it so the batch files
        # together hold exactly num_samples samples. The slice is a no-op for earlier batches.
        remaining = num_samples - i * self.total_bs
        batch = {
            k: batch[k][:remaining].astype(np.float32) for k in self.samples_outputs.keys()
        }

        filename_i = '{}.batch{:05d}'.format(filename, i)
        print('writing samples for batch', i, 'to:', filename_i)
        with tf.io.gfile.GFile(filename_i, 'wb') as f:
            f.write(pickle.dumps(batch, protocol=pickle.HIGHEST_PROTOCOL))
    print('done writing samples')
| 445 | |
| 446 | def run(self, logdir, once: bool, skip_non_ema_pass=True, dump_samples_only=False, load_ckpt=None, samples_dir=None, seed=0): |
| 447 | """Runs the eval/sampling worker loop. |