MCPcopy
hub / github.com/tensorlayer/TensorLayer / build

Method build

tensorlayer/layers/normalization.py:726–770  ·  view source on GitHub ↗
(self, inputs_shape)

Source from the content-addressed store, hash-verified

724 )
725
def build(self, inputs_shape):
    """Validate the configuration, then create the reshape target and weights.

    Sets ``self.int_shape`` (the 5-D shape used to split the channel axis
    into ``groups`` x ``channels // groups``), ``self.gamma`` and
    ``self.beta``.

    Parameters
    ----------
    inputs_shape : list or tuple
        Shape of the 4-D input. The channel axis is read from the last
        position for ``'channels_last'`` and from position 1 for
        ``'channels_first'``.

    Raises
    ------
    ValueError
        If ``inputs_shape`` is not 4-D, if ``self.data_format`` is not
        ``'channels_last'`` or ``'channels_first'``, if ``self.groups`` is
        not in ``[1, channels]``, or if ``channels`` is not a multiple of
        ``self.groups``.
    """
    if len(inputs_shape) != 4:
        raise ValueError("This GroupNorm only supports 2D images.")

    if self.data_format == 'channels_last':
        channels = inputs_shape[-1]
    elif self.data_format == 'channels_first':
        channels = inputs_shape[1]
    else:
        raise ValueError("data_format must be 'channels_last' or 'channels_first'.")

    # Validate the group count *before* dividing by it, so a bad value is
    # reported as a ValueError and no attribute is set on failure.
    if self.groups <= 0 or self.groups > channels:
        raise ValueError('Invalid groups %d for %d channels.' % (self.groups, channels))
    if channels % self.groups != 0:
        raise ValueError('%d channels is not commensurate with %d groups.' % (channels, self.groups))

    group_dims = [self.groups, channels // self.groups]

    # The batch dimension is left as -1 so that tf.reshape infers it at call
    # time; baking in the build-time batch size would break any later call
    # with a different (or unknown) batch size.
    if self.data_format == 'channels_last':
        # (N, H, W, C) -> (N, H, W, groups, C // groups)
        target_shape = [-1] + list(inputs_shape[1:3]) + group_dims
        weight_shape = channels
    else:
        # (N, C, H, W) -> (N, groups, C // groups, H, W)
        target_shape = [-1] + group_dims + list(inputs_shape[2:4])
        # Explicit singleton axes so the weights line up with the channel
        # axis at position 1.
        weight_shape = [1, channels, 1, 1]

    self.int_shape = tf.convert_to_tensor(value=target_shape)
    self.gamma = self._get_weights("gamma", shape=weight_shape, init=tl.initializers.ones())
    self.beta = self._get_weights("beta", shape=weight_shape, init=tl.initializers.zeros())
771
772 def forward(self, inputs):
773 x = tf.reshape(inputs, self.int_shape)

Callers

nothing calls this directly

Calls 1

_get_weightsMethod · 0.80

Tested by

no test coverage detected