MCPcopy
hub / github.com/hojonathanho/diffusion / evaluation

Function evaluation

scripts/run_lsun.py:85–113  ·  view source on GitHub ↗
(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords', samples_dir=None, num_inception_samples=2048,
)

Source from the content-addressed store, hash-verified

83
84
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords', samples_dir=None, num_inception_samples=2048,
    limit_dataset_size=30000,
):
    """Evaluate a trained LSUN diffusion model with a TPU ``EvalWorker``.

    Resolves the TFRecord file inside a region-suffixed GCS bucket and reloads
    the kwargs stored for the run in ``model_dir``. It then builds a ``Model``
    from those kwargs and runs ``tpu_utils.EvalWorker`` with
    ``logdir=model_dir``.

    Args:
        model_dir: Directory passed to ``tpu_utils.load_train_kwargs`` and used
            as the worker's ``logdir``.
        tpu_name: TPU name handed to ``tpu_utils.EvalWorker``.
        bucket_name_prefix: Bucket prefix. The bucket actually read is
            ``'{bucket_name_prefix}-{region}'``, where ``region`` comes from
            ``utils.get_gcp_region()``.
        once: Forwarded to ``worker.run``. Presumably evaluates a single time
            instead of repeatedly; TODO confirm in ``EvalWorker.run``.
        dump_samples_only: Forwarded to ``worker.run``. Presumably only writes
            samples without computing metrics; confirm.
        total_bs: Passed to the worker as both ``total_bs`` and
            ``inception_bs``.
        tfr_file: TFRecord path relative to the bucket root. It is rewritten
            into a full ``gs://`` URI before the dataset is built.
        samples_dir: Forwarded to ``worker.run``.
        num_inception_samples: Forwarded to the worker as
            ``num_inception_samples``.
        limit_dataset_size: Limit on dataset size used when computing Inception
            features, for memory reasons. The default of 30000 is the value that
            used to be hard-coded, so existing callers see no change.
    """
    # The TFRecord lives in a per-region bucket named '<prefix>-<region>'.
    region = utils.get_gcp_region()
    tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)

    # Reload the kwargs recorded for this run so the model/dataset below are
    # rebuilt from the same settings (presumably those used at training time).
    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    ds = datasets.get_dataset(kwargs['dataset'], tfr_file=tfr_file)

    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        # A constructor (not an instance) is handed over. Presumably the worker
        # builds the model inside its own graph/TPU context; NOTE(review) confirm.
        model_constructor=lambda: Model(
            model_name=kwargs['model_name'],
            betas=get_beta_schedule(
                kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
                num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
            ),
            loss_type=kwargs['loss_type'],
            num_classes=ds.num_classes,
            dropout=kwargs['dropout'],
            randflip=kwargs['randflip'],
            block_size=kwargs['block_size']
        ),
        total_bs=total_bs, inception_bs=total_bs, num_inception_samples=num_inception_samples,
        dataset=ds,
        limit_dataset_size=limit_dataset_size  # limit size of dataset for computing Inception features, for memory reasons
    )
    worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True, dump_samples_only=dump_samples_only,
               samples_dir=samples_dir)
114
115
116def train(

Callers

nothing calls this directly

Calls 3

run · Method · 0.95
get_beta_schedule · Function · 0.90
Model · Class · 0.70

Tested by

no test coverage detected