(self, tpu_name, model_constructor, total_bs, dataset, inception_bs=8, num_inception_samples=1024, limit_dataset_size=0)
| 254 | |
| 255 | class EvalWorker: |
def __init__(self, tpu_name, model_constructor, total_bs, dataset, inception_bs=8, num_inception_samples=1024, limit_dataset_size=0):
    """Connect to the TPU and build the evaluation graphs.

    Sets up the TPU strategy, the Inception feature extractors over the real
    data, the model, and two sampling graphs: one built directly and one built
    inside ``utils.ema_scope``.

    Args:
      tpu_name: forwarded as ``tpu=`` to ``TPUClusterResolver``.
      model_constructor: zero-argument callable; its result must be a ``Model``.
      total_bs: global batch size; must be divisible by the replica count.
      dataset: must expose ``num_classes`` (required to equal 1),
        ``train_one_pass_input_fn`` and ``eval_input_fn``.
      inception_bs: global Inception batch size; must be divisible by the
        replica count.
      num_inception_samples: stored on the instance; not otherwise used here.
      limit_dataset_size: floor-divided by ``total_bs`` and handed to
        ``InceptionFeatures`` as its ``limit_dataset_size``.

    Raises:
      AssertionError: if a batch size is not divisible by the replica count,
        the dataset has more than one class, the per-replica image batch
        dimension differs from the local batch size, or the constructed model
        is not a ``Model``.
    """
    # NOTE(review): the listing this method was reconstructed from carried no
    # indentation; the nesting under the strategy scope below is inferred
    # from the comments and statement order -- confirm against the real file.

    # --- TPU connection -----------------------------------------------------
    self.resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
    tf.tpu.experimental.initialize_tpu_system(self.resolver)
    self.strategy = tf.distribute.experimental.TPUStrategy(self.resolver)

    # --- Batch-size bookkeeping ---------------------------------------------
    self.num_cores = self.strategy.num_replicas_in_sync
    assert total_bs % self.num_cores == 0
    self.total_bs = total_bs
    self.local_bs = total_bs // self.num_cores
    print('num cores: {}'.format(self.num_cores))
    print('total batch size: {}'.format(self.total_bs))
    print('local batch size: {}'.format(self.local_bs))
    self.num_inception_samples = num_inception_samples
    assert inception_bs % self.num_cores == 0
    self.inception_bs = inception_bs
    self.inception_local_bs = inception_bs // self.num_cores
    self.dataset = dataset
    assert dataset.num_classes == 1, 'not supported'

    # --- Everything below is built inside the TPU strategy scope ------------
    with self.strategy.scope():
        # Inception features of the real data.
        print('===== INCEPTION =====')
        # The training set, iterated once (no repeat, no shuffling).
        train_input = dataset.train_one_pass_input_fn(params={'batch_size': total_bs})
        self.inception_real_train = InceptionFeatures(
            dataset=train_input,
            strategy=self.strategy,
            limit_dataset_size=limit_dataset_size // total_bs,
        )
        # The validation split is optional: the input fn may return None.
        val_input = dataset.eval_input_fn(params={'batch_size': total_bs})
        if val_input is None:
            self.inception_real_val = None
        else:
            self.inception_real_val = InceptionFeatures(
                dataset=val_input,
                strategy=self.strategy,
                limit_dataset_size=limit_dataset_size // total_bs,
            )

        # Leading dimension of the iterator's image shape must equal the
        # per-replica batch size; the remaining dims are passed on as the
        # image shape for sampling.
        batch_shape = self.inception_real_train.ds_iterator.output_shapes['image'].as_list()
        assert batch_shape[0] == self.local_bs

        # Model
        self.model = model_constructor()
        assert isinstance(self.model, Model)

        # Sampling graph built with the variables as constructed above.
        print('===== SAMPLES =====')
        plain_graph = self._make_sampling_graph(
            img_shape=batch_shape[1:], with_inception=True)
        self.samples_outputs, self.samples_inception = plain_graph

        # EMA over the trainable variables that exist at this point.
        self.global_step = tf.train.get_or_create_global_step()
        print('===== EMA =====')
        # NOTE(review): ema_decay=1e-10 is effectively zero; presumably the
        # averaged values are restored from a checkpoint rather than
        # accumulated here -- confirm.
        ema_obj, _unused = make_ema(
            global_step=self.global_step,
            ema_decay=1e-10,
            trainable_variables=tf.trainable_variables(),
        )

        # Same sampling graph again, this time built inside the EMA scope.
        with utils.ema_scope(ema_obj):
            print('===== EMA SAMPLES =====')
            ema_graph = self._make_sampling_graph(
                img_shape=batch_shape[1:], with_inception=True)
            self.ema_samples_outputs, self.ema_samples_inception = ema_graph
| 309 | |
| 310 | def _make_sampling_graph(self, img_shape, with_inception): |
| 311 |
nothing calls this directly
no test coverage detected