MCPcopy
hub / github.com/hojonathanho/diffusion / __init__

Method __init__

diffusion_tf/tpu_utils/tpu_utils.py:256–308  ·  view source on GitHub ↗
(self, tpu_name, model_constructor, total_bs, dataset, inception_bs=8, num_inception_samples=1024, limit_dataset_size=0)

Source from the content-addressed store, hash-verified

254
255class EvalWorker:
256 def __init__(self, tpu_name, model_constructor, total_bs, dataset, inception_bs=8, num_inception_samples=1024, limit_dataset_size=0):
257
258 self.resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
259 tf.tpu.experimental.initialize_tpu_system(self.resolver)
260 self.strategy = tf.distribute.experimental.TPUStrategy(self.resolver)
261
262 self.num_cores = self.strategy.num_replicas_in_sync
263 assert total_bs % self.num_cores == 0
264 self.total_bs = total_bs
265 self.local_bs = total_bs // self.num_cores
266 print('num cores: {}'.format(self.num_cores))
267 print('total batch size: {}'.format(self.total_bs))
268 print('local batch size: {}'.format(self.local_bs))
269 self.num_inception_samples = num_inception_samples
270 assert inception_bs % self.num_cores == 0
271 self.inception_bs = inception_bs
272 self.inception_local_bs = inception_bs // self.num_cores
273 self.dataset = dataset
274 assert dataset.num_classes == 1, 'not supported'
275
276 # TPU context
277 with self.strategy.scope():
278 # Inception network on real data
279 print('===== INCEPTION =====')
280 # Eval dataset iterator (this is the training set without repeat & shuffling)
281 self.inception_real_train = InceptionFeatures(
282 dataset=dataset.train_one_pass_input_fn(params={'batch_size': total_bs}), strategy=self.strategy, limit_dataset_size=limit_dataset_size // total_bs)
283 # Val dataset, if it exists
284 val_ds = dataset.eval_input_fn(params={'batch_size': total_bs})
285 self.inception_real_val = None if val_ds is None else InceptionFeatures(dataset=val_ds, strategy=self.strategy, limit_dataset_size=limit_dataset_size // total_bs)
286
287 img_batch_shape = self.inception_real_train.ds_iterator.output_shapes['image'].as_list()
288 assert img_batch_shape[0] == self.local_bs
289
290 # Model
291 self.model = model_constructor()
292 assert isinstance(self.model, Model)
293
294 # Eval/samples graphs
295 print('===== SAMPLES =====')
296 self.samples_outputs, self.samples_inception = self._make_sampling_graph(
297 img_shape=img_batch_shape[1:], with_inception=True)
298
299 # Model with EMA parameters
300 self.global_step = tf.train.get_or_create_global_step()
301 print('===== EMA =====')
302 ema, _ = make_ema(global_step=self.global_step, ema_decay=1e-10, trainable_variables=tf.trainable_variables())
303
304 # EMA versions of the above
305 with utils.ema_scope(ema):
306 print('===== EMA SAMPLES =====')
307 self.ema_samples_outputs, self.ema_samples_inception = self._make_sampling_graph(
308 img_shape=img_batch_shape[1:], with_inception=True)
309
310 def _make_sampling_graph(self, img_shape, with_inception):
311

Callers

nothing calls this directly

Calls 5

_make_sampling_graph — Method · 0.95
InceptionFeatures — Class · 0.85
make_ema — Function · 0.85
eval_input_fn — Method · 0.45

Tested by

no test coverage detected