MCPcopy
hub / github.com/hojonathanho/diffusion / evaluation

Function evaluation

scripts/run_celebahq.py:103–129  ·  view source on GitHub ↗
(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfds_data_dir='tensorflow_datasets',
)

Source from the content-addressed store, hash-verified

101
102
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfds_data_dir='tensorflow_datasets', num_inception_samples=2048, inception_bs=None,
):
    """Run evaluation for a trained model on a TPU via ``tpu_utils.EvalWorker``.

    The model is rebuilt from the keyword arguments that were saved for the
    training run in ``model_dir``. The worker is then run with ``model_dir`` as
    its log directory.

    Args:
      model_dir: Directory of the training run. The training kwargs are loaded
        from it with ``tpu_utils.load_train_kwargs``, and it is passed to
        ``EvalWorker.run`` as ``logdir``.
      tpu_name: Name of the TPU, forwarded to ``EvalWorker``.
      bucket_name_prefix: Prefix of the GCS bucket that holds the TFDS data.
        The bucket used is ``'{bucket_name_prefix}-{region}'``, where
        ``region`` comes from ``utils.get_gcp_region()``.
      once: Forwarded to ``EvalWorker.run``.
      dump_samples_only: Forwarded to ``EvalWorker.run``.
      total_bs: Total batch size, forwarded to ``EvalWorker``.
      tfds_data_dir: Path of the TFDS data directory inside that bucket.
      num_inception_samples: Forwarded to ``EvalWorker``. The default of 2048
        is the value that used to be hard-coded here.
      inception_bs: Batch size forwarded to ``EvalWorker`` as
        ``inception_bs``. When ``None`` (the default) it falls back to
        ``total_bs``, which matches the previous behavior.
    """
    region = utils.get_gcp_region()
    # Resolve the bucket-relative data dir to a full gs:// URL in this region's bucket.
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)

    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=tfds_data_dir)

    if inception_bs is None:
        inception_bs = total_bs

    def model_constructor():
        # Rebuild the model with the exact hyperparameters saved at training time.
        # NOTE(review): EvalWorker receives a constructor rather than an instance,
        # presumably so it can build the model in its own graph/TPU context — confirm
        # in tpu_utils.
        return Model(
            model_name=kwargs['model_name'],
            betas=get_beta_schedule(
                kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
                num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
            ),
            loss_type=kwargs['loss_type'],
            num_classes=ds.num_classes,
            dropout=kwargs['dropout'],
            randflip=kwargs['randflip'],
            block_size=kwargs['block_size']
        )

    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=model_constructor,
        total_bs=total_bs, inception_bs=inception_bs, num_inception_samples=num_inception_samples,
        dataset=ds,
    )
    # NOTE(review): skip_non_ema_pass=True presumably restricts evaluation to the
    # EMA weights — confirm against EvalWorker.run.
    worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True, dump_samples_only=dump_samples_only)
130
131
132def train(

Callers

nothing calls this directly

Calls 3

run — Method · 0.95
get_beta_schedule — Function · 0.90
Model — Class · 0.70

Tested by

no test coverage detected