(
exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='celebahq256',
optimizer='adam', total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
dropout=0.0, randflip=1, block_size=1,
tfds_data_dir='tensorflow_datasets', log_dir='logs'
)
| 130 | |
| 131 | |
def train(
    exp_name,
    tpu_name,
    bucket_name_prefix,
    model_name='unet2d16b2c112244',
    dataset='celebahq256',
    optimizer='adam',
    total_bs=64,
    grad_clip=1.,
    lr=0.00002,
    warmup=5000,
    num_diffusion_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule='linear',
    loss_type='noisepred',
    dropout=0.0,
    randflip=1,
    block_size=1,
    tfds_data_dir='tensorflow_datasets',
    log_dir='logs'
):
    """Configure a training run and hand it to ``tpu_utils.run_training``.

    The GCP region is looked up with ``utils.get_gcp_region()`` and used,
    together with ``bucket_name_prefix``, to rewrite ``tfds_data_dir`` and
    ``log_dir`` into ``gs://<bucket_name_prefix>-<region>/<dir>`` locations.
    A snapshot of the local variables taken right after that rewrite is used
    both to render the full experiment name and as ``dump_kwargs``.

    Args:
        exp_name: Leading component of the generated experiment name.
        tpu_name: Passed to ``run_training`` as ``tpu``.
        bucket_name_prefix: First part of the ``gs://`` bucket name; the
            region is appended after a dash.
        model_name, loss_type, dropout, randflip, block_size: Passed to the
            ``Model`` constructor.
        dataset: Name given to ``datasets.get_dataset``.
        optimizer, total_bs, grad_clip, lr, warmup: Passed to
            ``run_training``.
        num_diffusion_timesteps, beta_start, beta_end, beta_schedule: Passed
            to ``get_beta_schedule`` to produce the model's ``betas``.
        tfds_data_dir: Directory name inside the bucket, given to
            ``datasets.get_dataset`` after being expanded to a ``gs://`` path.
        log_dir: Directory name inside the bucket, given to ``run_training``
            after being expanded to a ``gs://`` path.
    """
    # NOTE: do not bind any new local name above the `dict(locals())` line --
    # the snapshot must contain exactly the arguments plus `region`.
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    kwargs = dict(locals())

    ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)

    # The snapshot supplies every placeholder of the experiment-name template.
    name_template = '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'
    full_exp_name = name_template.format(**kwargs)

    def _build_model():
        """Zero-argument factory; the model is only built when this is called."""
        betas = get_beta_schedule(
            beta_schedule,
            beta_start=beta_start,
            beta_end=beta_end,
            num_diffusion_timesteps=num_diffusion_timesteps
        )
        return Model(
            model_name=model_name,
            betas=betas,
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
            block_size=block_size
        )

    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name=full_exp_name,
        model_constructor=_build_model,
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=kwargs
    )
| 163 | |
| 164 | |
| 165 | if __name__ == '__main__': |
nothing calls this directly
no test coverage detected