(self, inputs_shape)
| 724 | ) |
| 725 | |
def build(self, inputs_shape):
    """Create the GroupNorm parameters and the grouped reshape target.

    Parameters
    ----------
    inputs_shape : sequence of int
        Rank-4 shape of the layer input, laid out according to
        ``self.data_format`` ('channels_last' or 'channels_first').

    Side effects
    ------------
    Sets ``self.int_shape``, a 1-D tensor of length 5 that ``forward`` passes
    to ``tf.reshape`` so the channel axis is split into ``self.groups`` groups.
    Creates ``self.gamma`` (initialised to ones) and ``self.beta``
    (initialised to zeros), one value per channel.

    Raises
    ------
    ValueError
        If the input is not rank 4, ``data_format`` is unrecognised, the
        channel dimension is unknown, or ``self.groups`` is not a positive
        divisor of the channel count.
    """
    if len(inputs_shape) != 4:
        # ValueError subclasses Exception, so existing handlers still catch it.
        raise ValueError("This GroupNorm only supports 2D images.")

    if self.data_format == 'channels_last':
        channels = inputs_shape[-1]
        param_shape = channels
    elif self.data_format == 'channels_first':
        channels = inputs_shape[1]
        param_shape = [1, channels, 1, 1]
    else:
        raise ValueError("data_format must be 'channels_last' or 'channels_first'.")

    if channels is None:
        raise ValueError("The channel dimension of the inputs must be known.")

    # Validate before computing channels // groups. groups == 0 would raise
    # ZeroDivisionError, and a negative divisor (e.g. 8 % -2 == 0) would slip
    # past the modulo check and yield a negative reshape dimension.
    if self.groups <= 0 or self.groups > channels:
        raise ValueError('Invalid groups %d for %d channels.' % (self.groups, channels))
    if channels % self.groups != 0:
        raise ValueError('%d channels is not commensurate with %d groups.' % (channels, self.groups))

    group_dims = tf.convert_to_tensor(value=[self.groups, channels // self.groups])
    # NOTE(review): the non-channel dims are copied from inputs_shape, so the
    # reshape target is fixed at build time. If d0 is the batch size, later
    # calls with a different batch size will fail to reshape. Confirm with
    # the callers.
    if self.data_format == 'channels_last':
        # [d0, d1, d2, C] -> [d0, d1, d2, G, C // G]
        self.int_shape = tf.concat([inputs_shape[0:3], group_dims], axis=0)
    else:
        # [d0, C, d2, d3] -> [d0, G, C // G, d2, d3]
        self.int_shape = tf.concat([inputs_shape[0:1], group_dims, inputs_shape[2:4]], axis=0)

    self.gamma = self._get_weights("gamma", shape=param_shape, init=tl.initializers.ones())
    self.beta = self._get_weights("beta", shape=param_shape, init=tl.initializers.zeros())
| 771 | |
| 772 | def forward(self, inputs): |
| 773 | x = tf.reshape(inputs, self.int_shape) |
nothing calls this directly
no test coverage detected