(
model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords', samples_dir=None, num_inception_samples=2048,
)
| 83 | |
| 84 | |
def evaluation(
        model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
        tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords', samples_dir=None, num_inception_samples=2048,
        limit_dataset_size=30000,
):
    """Evaluate a trained model saved in ``model_dir`` using ``tpu_utils.EvalWorker``.

    The model hyperparameters are not passed in; they are re-read from the
    training kwargs stored in ``model_dir`` so evaluation matches training.

    Args:
        model_dir: Directory whose saved training kwargs are loaded via
            ``tpu_utils.load_train_kwargs``; also used as ``logdir`` for
            ``worker.run``.
        tpu_name: TPU name handed to ``EvalWorker``.
        bucket_name_prefix: Combined with the current GCP region (from
            ``utils.get_gcp_region()``) to form the bucket
            ``gs://{bucket_name_prefix}-{region}/``.
        once: Forwarded unchanged to ``worker.run``.
        dump_samples_only: Forwarded unchanged to ``worker.run``.
        total_bs: Batch size; used for both ``total_bs`` and ``inception_bs``
            of the worker.
        tfr_file: TFRecord path relative to the bucket root.
        samples_dir: Forwarded unchanged to ``worker.run``.
        num_inception_samples: Forwarded unchanged to ``EvalWorker``.
        limit_dataset_size: Cap on the dataset size used when computing
            Inception features, for memory reasons. Defaults to 30000, the
            value previously hard-coded here.
    """
    region = utils.get_gcp_region()
    # The bucket name is region-suffixed, so the same prefix resolves to a
    # region-local bucket.
    tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)

    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    ds = datasets.get_dataset(kwargs['dataset'], tfr_file=tfr_file)

    def model_constructor():
        # Builds the model from the saved training hyperparameters.
        # EvalWorker receives a zero-arg factory rather than an instance;
        # presumably it constructs the model in its own context (TODO confirm).
        return Model(
            model_name=kwargs['model_name'],
            betas=get_beta_schedule(
                kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
                num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
            ),
            loss_type=kwargs['loss_type'],
            num_classes=ds.num_classes,
            dropout=kwargs['dropout'],
            randflip=kwargs['randflip'],
            block_size=kwargs['block_size']
        )

    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=model_constructor,
        total_bs=total_bs, inception_bs=total_bs, num_inception_samples=num_inception_samples,
        dataset=ds,
        limit_dataset_size=limit_dataset_size  # limit size of dataset for computing Inception features, for memory reasons
    )
    worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True, dump_samples_only=dump_samples_only,
               samples_dir=samples_dir)
| 114 | |
| 115 | |
| 116 | def train( |
nothing calls this directly
no test coverage detected