hub / github.com/tensorpack/tensorpack / get_bn_variables

Function get_bn_variables

tensorpack/models/batch_norm.py:25–44
(n_out, use_scale, use_bias, beta_init, gamma_init)

Source from the content-addressed store, hash-verified

def get_bn_variables(n_out, use_scale, use_bias, beta_init, gamma_init):
    if use_bias:
        beta = tf.get_variable('beta', [n_out], initializer=beta_init)
    else:
        beta = tf.zeros([n_out], name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
    else:
        gamma = tf.ones([n_out], name='gamma')
    # x * gamma + beta

    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.constant_initializer(1.0), trainable=False)

    if get_current_tower_context().is_main_training_tower:
        for v in [moving_mean, moving_var]:
            tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
    return beta, gamma, moving_mean, moving_var
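
The four values returned by get_bn_variables are the usual batch-norm parameters: a per-channel shift (beta), a per-channel scale (gamma), and moving estimates of the mean and variance. As a rough illustration of how such values are typically combined at inference time, here is a minimal pure-Python sketch. It is not tensorpack code: the helper names initial_bn_values and bn_inference and the eps default are hypothetical, and plain lists stand in for tensors. It mirrors only the use_scale=False, use_bias=False case, where the source falls back to zeros and ones.

```python
import math


def initial_bn_values(n_out):
    """Hypothetical stand-in for the values get_bn_variables starts from
    when use_scale=False and use_bias=False (plain lists, not tensors)."""
    beta = [0.0] * n_out         # matches tf.zeros([n_out])
    gamma = [1.0] * n_out        # matches tf.ones([n_out])
    moving_mean = [0.0] * n_out  # tf.constant_initializer() defaults to 0
    moving_var = [1.0] * n_out   # tf.constant_initializer(1.0)
    return beta, gamma, moving_mean, moving_var


def bn_inference(x, beta, gamma, moving_mean, moving_var, eps=1e-5):
    """Standard batch-norm inference formula, applied per channel:
    y = gamma * (x - mean) / sqrt(var + eps) + beta.
    The eps default is an assumption made for this sketch."""
    return [
        g * (xi - m) / math.sqrt(v + eps) + b
        for xi, b, g, m, v in zip(x, beta, gamma, moving_mean, moving_var)
    ]


x = [0.5, -2.0, 3.0]
beta, gamma, moving_mean, moving_var = initial_bn_values(len(x))
y = bn_inference(x, beta, gamma, moving_mean, moving_var)
```

With these initial values (mean 0, variance 1, scale 1, shift 0) the transform is very nearly the identity, so y differs from x only by the small effect of eps.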

Callers 1

BatchNorm · Function · 0.70

Calls 2

get_variable · Method · 0.80

Tested by

no test coverage detected