(
model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
tfds_data_dir='tensorflow_datasets',
)
| 101 | |
| 102 | |
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfds_data_dir='tensorflow_datasets', num_inception_samples=2048,
):
    """Build and run an evaluation worker for the training run in ``model_dir``.

    The training keyword arguments saved for ``model_dir`` are reloaded, so the
    model under evaluation is constructed with the same hyperparameters that
    were used for training.

    Args:
      model_dir: Training run directory. It is read via
        ``tpu_utils.load_train_kwargs`` and also passed to the worker as ``logdir``.
      tpu_name: TPU name forwarded to ``tpu_utils.EvalWorker``.
      bucket_name_prefix: GCS bucket prefix. The dataset directory is resolved
        as ``gs://<bucket_name_prefix>-<region>/<tfds_data_dir>``.
      once: Forwarded unchanged to ``EvalWorker.run``.
      dump_samples_only: Forwarded unchanged to ``EvalWorker.run``.
      total_bs: Total evaluation batch size; also used as the Inception batch size.
      tfds_data_dir: TFDS data directory, relative to the bucket root.
      num_inception_samples: Number of samples the worker uses for its
        Inception-network metrics (presumably IS/FID -- confirm in EvalWorker).
        Defaults to 2048, the value previously hard-coded here.
    """
    # Region-specific bucket: presumably the region this process runs in
    # (as reported by utils.get_gcp_region) -- confirm.
    region = utils.get_gcp_region()
    gcs_tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)

    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=gcs_tfds_data_dir)

    def _make_model():
        # Passed as a factory rather than an instance so the worker controls
        # when/where the model is built (presumably inside its TPU/graph
        # context -- confirm in EvalWorker).
        return Model(
            model_name=kwargs['model_name'],
            betas=get_beta_schedule(
                kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
                num_diffusion_timesteps=kwargs['num_diffusion_timesteps'],
            ),
            loss_type=kwargs['loss_type'],
            num_classes=ds.num_classes,
            dropout=kwargs['dropout'],
            randflip=kwargs['randflip'],
            block_size=kwargs['block_size'],
        )

    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=_make_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=num_inception_samples,
        dataset=ds,
    )
    # The non-EMA evaluation pass is always skipped here.
    worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True, dump_samples_only=dump_samples_only)
| 130 | |
| 131 | |
| 132 | def train( |
nothing calls this directly
no test coverage detected