github.com/hojonathanho/diffusion / train

Function train

scripts/run_celebahq.py:132–162
(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='celebahq256',
    optimizer='adam', total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfds_data_dir='tensorflow_datasets', log_dir='logs'
)
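`train` passes `lr`, `warmup` and `grad_clip` unchanged to `tpu_utils.run_training`. The sketch below assumes these arguments mean a linear learning-rate warmup over the first `warmup` steps and clipping of the gradients by their global L2 norm. This is an assumption about `run_training`, not something `train` itself shows. The helper names `warmup_lr` and `clip_by_global_norm` are hypothetical, and the gradients are plain Python lists rather than tensors.

```python
import math
from typing import List, Tuple


def warmup_lr(step: int, lr: float = 0.00002, warmup: int = 5000) -> float:
    """Hypothetical linear warmup: ramp from 0 to `lr` over `warmup` steps, then hold."""
    if warmup <= 0:
        return lr
    return lr * min(1.0, step / warmup)


def clip_by_global_norm(grads: List[List[float]],
                        clip_norm: float = 1.0) -> Tuple[List[List[float]], float]:
    """Hypothetical global-norm clipping: rescale every gradient by one shared factor
    so that the L2 norm taken over all of them is at most `clip_norm`."""
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = clip_norm / max(global_norm, clip_norm)  # 1.0 when already within bounds
    return [[g * scale for g in grad] for grad in grads], global_norm


# Using the defaults from the signature above (lr=0.00002, warmup=5000, grad_clip=1.):
print(warmup_lr(0), warmup_lr(2500), warmup_lr(10000))
clipped, norm = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], clip_norm=1.0)
print(norm, clipped)
```

Dividing by `max(global_norm, clip_norm)` leaves the gradients untouched when their joint norm is already within the limit. This is the same rule that TensorFlow documents for `tf.clip_by_global_norm`.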


```python
def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='celebahq256',
    optimizer='adam', total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfds_data_dir='tensorflow_datasets', log_dir='logs'
):
  # Resolve data and log directories to a per-region GCS bucket: gs://<prefix>-<region>/<dir>.
  region = utils.get_gcp_region()
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
  log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
  # Snapshot all arguments (plus region and the rewritten paths) for the run name and config dump.
  kwargs = dict(locals())
  ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)
  tpu_utils.run_training(
    date_str='9999-99-99',
    exp_name='{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'.format(
      **kwargs),
    model_constructor=lambda: Model(
      model_name=model_name,
      betas=get_beta_schedule(
        beta_schedule, beta_start=beta_start, beta_end=beta_end, num_diffusion_timesteps=num_diffusion_timesteps
      ),
      loss_type=loss_type,
      num_classes=ds.num_classes,
      dropout=dropout,
      randflip=randflip,
      block_size=block_size
    ),
    optimizer=optimizer, total_bs=total_bs, lr=lr, warmup=warmup, grad_clip=grad_clip,
    train_input_fn=ds.train_input_fn,
    tpu=tpu_name, log_dir=log_dir, dump_kwargs=kwargs
  )
```
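The first lines of `train` rewrite `tfds_data_dir` and `log_dir` into `gs://` URLs on a region-specific bucket. They then snapshot every local variable into `kwargs`, which is used both to build the run name and as `dump_kwargs`. The sketch below replays that string handling in plain Python, so the resulting names can be checked without a TPU or a GCP project. `make_paths` and `make_run_name` are hypothetical helpers. The bucket prefix `mybucket` and region `us-central1` are placeholder values standing in for `bucket_name_prefix` and the result of `utils.get_gcp_region()`.

```python
RUN_NAME_TEMPLATE = (
    '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}'
    '_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}'
    '_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'
)


def make_paths(bucket_name_prefix, region, tfds_data_dir='tensorflow_datasets', log_dir='logs'):
    """Same format strings as train(): gs://<prefix>-<region>/<dir>."""
    tfds = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    logs = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    return tfds, logs


def make_run_name(exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244',
                  dataset='celebahq256', optimizer='adam', total_bs=64, grad_clip=1.,
                  lr=0.00002, warmup=5000, num_diffusion_timesteps=1000, beta_start=0.0001,
                  beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
                  dropout=0.0, randflip=1, block_size=1):
    # Like train(), capture every argument via locals(); str.format ignores the
    # keys the template does not use (tpu_name, bucket_name_prefix, grad_clip).
    kwargs = dict(locals())
    return RUN_NAME_TEMPLATE.format(**kwargs)


print(make_paths('mybucket', 'us-central1'))
print(make_run_name('myexp', 'my-tpu', 'mybucket'))
```

Each float is rendered with `str()`, so the default `lr=0.00002` appears as `2e-05` in the run name, while `beta_start=0.0001` stays `0.0001`. In `train` itself, `locals()` is called after the paths are rewritten. As a result, `kwargs`, and therefore `dump_kwargs`, also records `region` and the `gs://` forms of both directories.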

Callers

nothing calls this directly

Calls (2)

get_beta_schedule · function · 0.90
Model · class · 0.70
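`train` builds `betas` with `get_beta_schedule('linear', beta_start=0.0001, beta_end=0.02, num_diffusion_timesteps=1000)` and passes `loss_type='noisepred'` to `Model`. The sketch below makes two assumptions. First, the linear schedule is taken to mean `num_diffusion_timesteps` evenly spaced values from `beta_start` to `beta_end`, endpoints included. Second, `noisepred` is taken to mean that the network is trained to predict the Gaussian noise added by the standard DDPM forward process, using a mean-squared-error loss. This is a pure-Python illustration, not the repository's TensorFlow code. The names `linear_betas`, `alphas_cumprod`, `q_sample` and `noisepred_loss` are hypothetical, and a fixed all-zeros guess stands in for the model.

```python
import math
import random
from typing import List


def linear_betas(beta_start=0.0001, beta_end=0.02, num_diffusion_timesteps=1000) -> List[float]:
    """Evenly spaced betas from beta_start to beta_end, endpoints included."""
    n = num_diffusion_timesteps
    step = (beta_end - beta_start) / (n - 1)
    return [beta_start + i * step for i in range(n)]


def alphas_cumprod(betas: List[float]) -> List[float]:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, acc = [], 1.0
    for b in betas:
        acc *= 1.0 - b
        out.append(acc)
    return out


def q_sample(x0: List[float], t: int, abar: List[float], noise: List[float]) -> List[float]:
    """x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise, i.e. a draw from q(x_t | x_0)."""
    a = math.sqrt(abar[t])
    s = math.sqrt(1.0 - abar[t])
    return [a * x + s * e for x, e in zip(x0, noise)]


def noisepred_loss(noise: List[float], predicted_noise: List[float]) -> float:
    """Mean squared error between the true and the predicted noise."""
    return sum((e - p) ** 2 for e, p in zip(noise, predicted_noise)) / len(noise)


rng = random.Random(0)
betas = linear_betas()
abar = alphas_cumprod(betas)
x0 = [rng.uniform(-1.0, 1.0) for _ in range(8)]   # stand-in for pixels scaled to [-1, 1]
t = rng.randrange(len(betas))                     # uniformly sampled timestep
noise = [rng.gauss(0.0, 1.0) for _ in range(8)]
x_t = q_sample(x0, t, abar, noise)
# A real model would map (x_t, t) to a noise estimate; all-zeros is a placeholder.
print(t, noisepred_loss(noise, [0.0] * 8))
```

With the default linear schedule, the betas sum to about 10.05, so `alpha_bar` at the last timestep is roughly 4e-5. This means x_T is very close to pure Gaussian noise.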

Tested by

no test coverage detected