MCPcopy
hub / github.com/hojonathanho/diffusion / model_fn

Function model_fn

diffusion_tf/tpu_utils/tpu_utils.py:130–180  ·  view source on GitHub ↗
(features, params, mode)

Source from the content-addressed store, hash-verified

128
129 # model_fn for TPUEstimator
def model_fn(features, params, mode):
    """Build the training graph and return a TPUEstimatorSpec (model_fn for TPUEstimator).

    Args:
        features: mapping with an 'image' entry and a 'label' entry. The
            leading dimension of 'image' must equal the local batch size;
            'label' must have shape [local batch size] and dtype tf.int32.
        params: mapping that supplies 'batch_size', used here as the local
            batch size.
        mode: must be tf.estimator.ModeKeys.TRAIN.

    Returns:
        A tf.estimator.tpu.TPUEstimatorSpec carrying the loss, the train op
        and the host call produced by the summary helper.

    Note:
        total_bs, model_constructor, lr, warmup, optimizer, grad_clip,
        ema_decay and model_dir are not defined in this function; they are
        resolved from the surrounding scope.
    """
    per_replica_bs = params['batch_size']
    print('Global batch size: {}, local batch size: {}'.format(total_bs, per_replica_bs))
    assert total_bs == num_tpu_replicas() * per_replica_bs

    assert mode == tf.estimator.ModeKeys.TRAIN, 'only TRAIN mode supported'
    images = features['image']
    assert images.shape[0] == per_replica_bs
    labels = features['label']
    assert labels.shape == [per_replica_bs]
    assert labels.dtype == tf.int32

    # Nothing below reads params again.
    del params

    # Instantiate the model.
    net = model_constructor()
    assert isinstance(net, Model)

    # Training loss: cast images to float32 and normalize before train_fn.
    float_images = tf.cast(images, tf.float32)
    train_info = net.train_fn(normalize_data(float_images), labels)
    loss = train_info['loss']
    assert loss.shape == []  # the loss must be a scalar

    # Collect variables only after the model has been built by train_fn.
    train_vars = tf.trainable_variables()
    param_count = sum(int(np.prod(v.shape.as_list())) for v in train_vars)
    print('num params: {:,}'.format(param_count))

    # Optimizer step with a warmed-up learning rate.
    step = tf.train.get_or_create_global_step()
    learning_rate = utils.get_warmed_up_lr(max_lr=lr, warmup=warmup, global_step=step)
    # NOTE(review): the clip threshold is divided by the replica count,
    # presumably to account for cross-replica gradient aggregation; confirm
    # against utils.make_optimizer.
    per_replica_clip = grad_clip / float(num_tpu_replicas())
    optimize_op, gnorm = utils.make_optimizer(
        loss=loss,
        trainable_variables=train_vars,
        global_step=step,
        lr=learning_rate,
        optimizer=optimizer,
        grad_clip=per_replica_clip,
        tpu=True,
    )

    # EMA update, created under a control dependency on the optimizer step.
    _, ema_update = make_ema(global_step=step, ema_decay=ema_decay, trainable_variables=train_vars)
    with tf.control_dependencies([optimize_op]):
        train_op = tf.group(ema_update)

    # Scalar summaries, handed to the estimator via host_call.
    summaries = TpuSummaries(model_dir, save_summary_steps=100)
    summaries.scalar('train/loss', loss)
    summaries.scalar('train/gnorm', gnorm)
    summaries.scalar('train/pnorm', utils.rms(train_vars))
    summaries.scalar('train/lr', learning_rate)
    host_call = summaries.get_host_call()

    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        host_call=host_call,
        loss=loss,
        train_op=train_op,
    )
181
182 # Set up Estimator and train
183 print("warm_start_from:", warm_start_from)

Callers

nothing calls this directly

Calls 6

scalar · Method · 0.95
get_host_call · Method · 0.95
num_tpu_replicas · Function · 0.85
make_ema · Function · 0.85
TpuSummaries · Class · 0.85
train_fn · Method · 0.45

Tested by

no test coverage detected