(n_out, use_scale, use_bias, beta_init, gamma_init)
| 23 | |
| 24 | |
def get_bn_variables(n_out, use_scale, use_bias, beta_init, gamma_init):
    """Build the affine parameters and moving statistics used by batch norm.

    Args:
        n_out: length of every returned 1-D tensor/variable (shape ``[n_out]``).
        use_scale: if true, ``gamma`` is a trainable variable named ``'gamma'``
            created with ``gamma_init``; otherwise it is a constant tensor of
            ones named ``'gamma'``.
        use_bias: if true, ``beta`` is a trainable variable named ``'beta'``
            created with ``beta_init``; otherwise it is a constant tensor of
            zeros named ``'beta'``.
        beta_init: initializer for the ``'beta'`` variable.
        gamma_init: initializer for the ``'gamma'`` variable.

    Returns:
        tuple ``(beta, gamma, moving_mean, moving_var)``. The two moving
        statistics are non-trainable variables named ``'mean/EMA'``
        (default constant initializer, i.e. zeros) and ``'variance/EMA'``
        (initialized to 1.0). When the current tower context reports
        ``is_main_training_tower``, both are also added to
        ``tf.GraphKeys.MODEL_VARIABLES``.
    """
    shape = [n_out]

    # Affine transform applied after normalization: x * gamma + beta.
    # A disabled term becomes a constant tensor, so four values are always returned.
    beta = (tf.get_variable('beta', shape, initializer=beta_init)
            if use_bias else tf.zeros(shape, name='beta'))
    gamma = (tf.get_variable('gamma', shape, initializer=gamma_init)
             if use_scale else tf.ones(shape, name='gamma'))

    # Running statistics: excluded from gradient updates (trainable=False).
    moving_mean = tf.get_variable(
        'mean/EMA', shape,
        initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable(
        'variance/EMA', shape,
        initializer=tf.constant_initializer(1.0), trainable=False)

    # Register the statistics only from the main training tower
    # (presumably so other towers don't add duplicates -- confirm).
    tower_ctx = get_current_tower_context()
    if tower_ctx.is_main_training_tower:
        for stat in (moving_mean, moving_var):
            tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, stat)

    return beta, gamma, moving_mean, moving_var
| 45 | |
| 46 | |
| 47 | def internal_update_bn_ema(xn, batch_mean, batch_var, |
no test coverage detected