MCPcopy
hub / github.com/tensorlayer/TensorLayer / build

Method build

tensorlayer/layers/normalization.py:845–861  ·  view source on GitHub ↗
(self, inputs_shape)

Source from the content-addressed store, hash-verified

843 )
844
def build(self, inputs_shape):
    """Create the trainable parameters of the SwitchNorm layer.

    Allocates, via ``self._get_weights``:

    * ``gamma`` and ``beta``: one entry per channel, initialised with
      ``self.gamma_init`` / ``self.beta_init``;
    * ``mean_weight`` and ``var_weight``: length-3 vectors initialised to 1.0.

    Parameters
    ----------
    inputs_shape : sequence
        Shape of the layer input. It must have exactly 4 entries. The last
        entry is used as the channel count (assumes a (batch, height, width,
        channels) layout -- TODO confirm against callers).

    Raises
    ------
    ValueError
        If ``inputs_shape`` does not have 4 entries, or if
        ``self.data_format`` is not ``'channels_last'``. ``ValueError``
        subclasses ``Exception``, so existing ``except Exception`` handlers
        still catch these.
    """
    if len(inputs_shape) != 4:
        raise ValueError("This SwitchNorm only supports 2D images.")
    if self.data_format != 'channels_last':
        raise ValueError("This SwitchNorm only supports channels_last.")

    # channels_last layout: the channel dimension is the final axis.
    ch = inputs_shape[-1]
    self.gamma = self._get_weights("gamma", shape=[ch], init=self.gamma_init)
    self.beta = self._get_weights("beta", shape=[ch], init=self.beta_init)

    # NOTE(review): each vector has 3 entries, presumably one mixing weight
    # per statistic source being switched between -- confirm in forward().
    self.mean_weight_var = self._get_weights("mean_weight", shape=[3], init=tl.initializers.constant(1.0))
    self.var_weight_var = self._get_weights("var_weight", shape=[3], init=tl.initializers.constant(1.0))
862
863 def forward(self, inputs):
864

Callers

nothing calls this directly

Calls 1

_get_weights · Method · 0.80

Tested by

no test coverage detected