(
exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='lsun',
optimizer='adam', total_bs=64, grad_clip=1., lr=2e-5, warmup=5000,
num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
dropout=0.0, randflip=1, block_size=1,
tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords', log_dir='logs',
warm_start_model_dir=None
)
| 114 | |
| 115 | |
def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='lsun',
    optimizer='adam', total_bs=64, grad_clip=1., lr=2e-5, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords', log_dir='logs',
    warm_start_model_dir=None
):
    """Resolve bucket paths, build the dataset, and start ``tpu_utils.run_training``.

    ``tfr_file`` and ``log_dir`` are given relative to the bucket
    ``gs://<bucket_name_prefix>-<region>/``, where ``<region>`` is whatever
    ``utils.get_gcp_region()`` returns; both are rewritten to full ``gs://``
    paths and printed before anything else happens.

    Args:
        exp_name: Prefix of the run name; the final name also encodes the main
            hyperparameters (see the template in the body).
        tpu_name: Forwarded to ``run_training`` as ``tpu``.
        bucket_name_prefix: Bucket name without the ``-<region>`` suffix.
        model_name, loss_type, dropout, randflip, block_size: Forwarded to
            the ``Model`` constructor.
        dataset: Name passed to ``datasets.get_dataset``.
        optimizer, total_bs, grad_clip, lr, warmup: Forwarded unchanged to
            ``run_training``.
        num_diffusion_timesteps, beta_start, beta_end, beta_schedule:
            Forwarded to ``get_beta_schedule`` to produce the model's betas.
        tfr_file: Bucket-relative path handed to the dataset as ``tfr_file``.
        log_dir: Bucket-relative path handed to ``run_training``.
        warm_start_model_dir: When truthy, the latest checkpoint found in
            this directory is used to warm-start every variable (``".*"``);
            otherwise no warm start is requested.
    """
    region = utils.get_gcp_region()

    tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)
    print("tfr_file:", tfr_file)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    print("log_dir:", log_dir)

    # Snapshot of every argument (with the resolved paths) plus ``region``.
    # It must be taken before any further local is bound, because the dict is
    # both the source for the run-name fields and the ``dump_kwargs`` payload.
    kwargs = dict(locals())

    ds = datasets.get_dataset(dataset, tfr_file=tfr_file)

    def build_model():
        # Constructed lazily: run_training decides when to call this.
        betas = get_beta_schedule(
            beta_schedule,
            beta_start=beta_start,
            beta_end=beta_end,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
        return Model(
            model_name=model_name,
            betas=betas,
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
            block_size=block_size,
        )

    if warm_start_model_dir:
        # NOTE(review): tf.train.latest_checkpoint returns None when it finds
        # no checkpoint, so ckpt_to_initialize_from may be None here --
        # confirm how run_training / the estimator reacts to that.
        warm_start = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=tf.train.latest_checkpoint(warm_start_model_dir),
            vars_to_warm_start=[".*"],
        )
    else:
        warm_start = None

    # Adjacent literals concatenate into the single run-name template.
    run_name = (
        '{exp_name}_{dataset}_{model_name}_{optimizer}'
        '_bs{total_bs}_lr{lr}w{warmup}'
        '_beta{beta_start}-{beta_end}-{beta_schedule}'
        '_t{num_diffusion_timesteps}_{loss_type}'
        '_dropout{dropout}_randflip{randflip}_blk{block_size}'
    ).format(**kwargs)

    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name=run_name,
        model_constructor=build_model,
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=kwargs,
        warm_start_from=warm_start,
    )
| 154 | |
| 155 | |
| 156 | if __name__ == '__main__': |
nothing calls this directly
no test coverage detected