MCPcopy
hub / github.com/hojonathanho/diffusion / evaluation

Function evaluation

scripts/run_cifar.py:113–129  ·  view source on GitHub ↗
(  # evaluation loop for use during training
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=256,
    tfds_data_dir='tensorflow_datasets', load_ckpt=None
)

Source from the content-addressed store, hash-verified

111
112
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=256,
    tfds_data_dir='tensorflow_datasets', load_ckpt=None
):
    """Evaluation loop for use during training.

    Loads the training kwargs saved in ``model_dir``, rebuilds the dataset from a
    GCS-hosted TFDS directory, and hands a model constructor to a
    ``tpu_utils.EvalWorker``, which is then run against ``model_dir``.

    Args:
      model_dir: Directory whose training kwargs are loaded; it is also passed
        as ``logdir`` to the worker.
      tpu_name: Forwarded to ``EvalWorker`` as ``tpu_name``.
      bucket_name_prefix: Prefix of the GCS bucket. The full TFDS path is
        ``gs://<bucket_name_prefix>-<region>/<tfds_data_dir>``, where the region
        comes from ``utils.get_gcp_region()``.
      once: Forwarded to ``EvalWorker.run``. Presumably it evaluates a single
        time instead of looping; TODO confirm in ``EvalWorker.run``.
      dump_samples_only: Forwarded to ``EvalWorker.run``.
      total_bs: Used as both ``total_bs`` and ``inception_bs`` for the worker.
      tfds_data_dir: Path of the TFDS data directory relative to the bucket root.
      load_ckpt: Forwarded to ``EvalWorker.run``.
    """
    # Resolve the region-specific bucket that holds the TFDS data.
    gcs_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, utils.get_gcp_region(), tfds_data_dir)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)

    dataset = datasets.get_dataset(train_kwargs['dataset'], tfds_data_dir=gcs_data_dir)

    # The worker builds the model lazily through this bound constructor.
    build_model = functools.partial(_load_model, kwargs=train_kwargs, ds=dataset)

    eval_worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=50000,
        dataset=dataset,
    )

    # Only the EMA weights are evaluated; the non-EMA pass is skipped.
    eval_worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
        load_ckpt=load_ckpt,
    )
130
131
132def train(

Callers

nothing calls this directly

Calls 1

runMethod · 0.95

Tested by

no test coverage detected