| 111 | |
| 112 | |
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=256,
    tfds_data_dir='tensorflow_datasets', load_ckpt=None
):
    """Evaluation loop for use during training.

    Rebuilds the model from the training kwargs stored under ``model_dir``
    and hands it to a ``tpu_utils.EvalWorker``, which is then run with
    ``model_dir`` as its log directory.

    Args:
        model_dir: Directory the training kwargs are loaded from; also passed
            to the worker as ``logdir``.
        tpu_name: Forwarded unchanged to ``tpu_utils.EvalWorker``.
        bucket_name_prefix: Prefix of the GCS bucket name; the bucket used is
            ``<bucket_name_prefix>-<region>``.
        once: Forwarded unchanged to ``EvalWorker.run``.
        dump_samples_only: Forwarded unchanged to ``EvalWorker.run``.
        total_bs: Used for both ``total_bs`` and ``inception_bs`` of the
            worker.
        tfds_data_dir: Path inside the bucket that is handed to
            ``datasets.get_dataset`` (after being prefixed with the bucket).
        load_ckpt: Forwarded unchanged to ``EvalWorker.run``.
    """
    # The dataset location is always resolved relative to the regional bucket.
    gcp_region = utils.get_gcp_region()
    bucket_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, gcp_region, tfds_data_dir)

    # Reuse the exact configuration the training job was launched with.
    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)

    dataset = datasets.get_dataset(train_kwargs['dataset'], tfds_data_dir=bucket_data_dir)

    # The worker receives a constructor rather than a model instance, so the
    # training kwargs and dataset are bound up front.
    build_model = functools.partial(_load_model, kwargs=train_kwargs, ds=dataset)
    eval_worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=50000,
        dataset=dataset,
    )

    # skip_non_ema_pass is fixed to True for this entry point.
    eval_worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
        load_ckpt=load_ckpt,
    )
| 130 | |
| 131 | |
| 132 | def train( |