MCPcopy
hub / github.com/hojonathanho/diffusion / train

Function train

scripts/run_lsun.py:116–153  ·  view source on GitHub ↗
(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='lsun',
    optimizer='adam', total_bs=64, grad_clip=1., lr=2e-5, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords', log_dir='logs',
    warm_start_model_dir=None
)

Source from the content-addressed store, hash-verified

114
115
def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='lsun',
    optimizer='adam', total_bs=64, grad_clip=1., lr=2e-5, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords', log_dir='logs',
    warm_start_model_dir=None
):
    """Launch diffusion-model training via ``tpu_utils.run_training``.

    ``tfr_file`` and ``log_dir`` are relative paths. They are resolved inside the
    region-specific bucket ``gs://<bucket_name_prefix>-<region>/``, where the
    region comes from ``utils.get_gcp_region()``.

    Args:
        exp_name: Prefix of the experiment name. Most hyperparameters below are
            appended to it to form the name passed to ``run_training``.
        tpu_name: Forwarded to ``run_training`` as ``tpu``.
        bucket_name_prefix: Bucket name prefix; ``-<region>`` is appended.
        model_name, loss_type, dropout, randflip, block_size: Forwarded to ``Model``.
        dataset: Dataset name given to ``datasets.get_dataset``. Its
            ``num_classes`` and ``train_input_fn`` are used for training.
        optimizer, total_bs, grad_clip, lr, warmup: Forwarded to ``run_training``.
        num_diffusion_timesteps, beta_start, beta_end, beta_schedule: Forwarded
            to ``get_beta_schedule``. The result is passed to ``Model`` as ``betas``.
        tfr_file: TFRecord path relative to the bucket.
        log_dir: Log directory relative to the bucket.
        warm_start_model_dir: If truthy, every variable (regex ``".*"``) is
            initialized from the latest checkpoint found in this directory.

    Raises:
        ValueError: If ``warm_start_model_dir`` is given but contains no checkpoint.
    """
    region = utils.get_gcp_region()
    tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    print("tfr_file:", tfr_file)
    print("log_dir:", log_dir)

    # Snapshot every argument (with the resolved gs:// paths) plus `region`.
    # It feeds both the experiment-name template and `dump_kwargs`. It must be
    # taken before any further local (checkpoint path, WarmStartSettings,
    # dataset, ...) is bound, or those would leak into the dump.
    kwargs = dict(locals())

    # Resolve the warm-start checkpoint up front. tf.train.latest_checkpoint
    # returns None when no checkpoint exists. Forwarding that None to
    # WarmStartSettings would only fail later, inside training setup, without
    # pointing at the bad directory.
    warm_start_from = None
    if warm_start_model_dir:
        warm_start_ckpt = tf.train.latest_checkpoint(warm_start_model_dir)
        if warm_start_ckpt is None:
            raise ValueError(
                'No checkpoint found in warm_start_model_dir: {}'.format(warm_start_model_dir))
        warm_start_from = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=warm_start_ckpt,
            vars_to_warm_start=[".*"]  # warm-start every variable
        )

    ds = datasets.get_dataset(dataset, tfr_file=tfr_file)

    # Identical to the original single-line template, split for readability.
    run_name = (
        '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}'
        '_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}'
        '_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'
    ).format(**kwargs)

    tpu_utils.run_training(
        # NOTE(review): fixed dummy date string; confirm how tpu_utils uses it.
        date_str='9999-99-99',
        exp_name=run_name,
        # Built lazily by run_training; the beta schedule is computed at construction time.
        model_constructor=lambda: Model(
            model_name=model_name,
            betas=get_beta_schedule(
                beta_schedule, beta_start=beta_start, beta_end=beta_end,
                num_diffusion_timesteps=num_diffusion_timesteps
            ),
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
            block_size=block_size
        ),
        optimizer=optimizer, total_bs=total_bs, lr=lr, warmup=warmup, grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name, log_dir=log_dir, dump_kwargs=kwargs,
        warm_start_from=warm_start_from
    )
154
155
156if __name__ == '__main__':

Callers

nothing calls this directly

Calls 2

get_beta_schedule — Function · 0.90
Model — Class · 0.70

Tested by

no test coverage detected