(features, params, mode)
| 128 | |
| 129 | # model_fn for TPUEstimator |
def model_fn(features, params, mode):
    """Build the TRAIN-mode graph for TPUEstimator and return its spec.

    Args:
        features: mapping with an 'image' entry and a 'label' entry. The
            leading dimension of the image tensor and the full shape of the
            label tensor must match the per-replica batch size; labels must
            be tf.int32.
        params: mapping from which 'batch_size' (the per-replica batch size)
            is read.
        mode: must be tf.estimator.ModeKeys.TRAIN.

    Returns:
        A tf.estimator.tpu.TPUEstimatorSpec carrying the loss, the train op
        (optimizer step grouped with the EMA op) and the summary host call.

    Besides its arguments, this function reads total_bs, model_constructor,
    lr, warmup, optimizer, grad_clip, ema_decay and model_dir, all of which
    are defined outside of it.
    """
    per_replica_bs = params['batch_size']
    print('Global batch size: {}, local batch size: {}'.format(total_bs, per_replica_bs))
    assert total_bs == num_tpu_replicas() * per_replica_bs

    # Sanity-check the mode and the input tensors before building anything.
    assert mode == tf.estimator.ModeKeys.TRAIN, 'only TRAIN mode supported'
    images = features['image']
    assert images.shape[0] == per_replica_bs
    labels = features['label']
    assert labels.shape == [per_replica_bs] and labels.dtype == tf.int32

    # Nothing else may be read from params past this point.
    del params

    ##########

    # Instantiate the model.
    net = model_constructor()
    assert isinstance(net, Model)

    # Training loss: images are cast to float32 and normalized first.
    float_images = tf.cast(features['image'], tf.float32)
    model_inputs = normalize_data(float_images)
    train_outputs = net.train_fn(model_inputs, features['label'])
    train_loss = train_outputs['loss']
    assert train_loss.shape == []  # loss must be a scalar

    # Optimizer step.
    variables = tf.trainable_variables()
    per_variable_sizes = [int(np.prod(v.shape.as_list())) for v in variables]
    print('num params: {:,}'.format(sum(per_variable_sizes)))
    step = tf.train.get_or_create_global_step()
    learning_rate = utils.get_warmed_up_lr(max_lr=lr, warmup=warmup, global_step=step)
    # The clipping threshold is divided by the number of TPU replicas.
    per_replica_clip = grad_clip / float(num_tpu_replicas())
    optimize_op, grad_norm = utils.make_optimizer(
        loss=train_loss,
        trainable_variables=variables,
        global_step=step,
        lr=learning_rate,
        optimizer=optimizer,
        grad_clip=per_replica_clip,
        tpu=True,
    )

    # Exponential moving average of the trainable variables.
    # NOTE(review): ema_op is created before the control_dependencies scope
    # below is entered, so only the grouping op is placed under the
    # dependency on the optimizer step -- confirm that make_ema itself orders
    # the EMA update after the step, if that ordering is required.
    _, ema_op = make_ema(global_step=step, ema_decay=ema_decay, trainable_variables=variables)
    with tf.control_dependencies([optimize_op]):
        final_train_op = tf.group(ema_op)

    # Scalar summaries, delivered through the estimator's host call.
    summaries = TpuSummaries(model_dir, save_summary_steps=100)
    summaries.scalar('train/loss', train_loss)
    summaries.scalar('train/gnorm', grad_norm)
    summaries.scalar('train/pnorm', utils.rms(variables))
    summaries.scalar('train/lr', learning_rate)
    host_call = summaries.get_host_call()

    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        host_call=host_call,
        loss=train_loss,
        train_op=final_train_op,
    )
| 181 | |
| 182 | # Set up Estimator and train |
| 183 | print("warm_start_from:", warm_start_from) |
nothing calls this directly
no test coverage detected