(self, inputs_shape)
| 843 | ) |
| 844 | |
def build(self, inputs_shape):
    """Validate the input shape and create this layer's weights.

    Parameters
    ----------
    inputs_shape : sequence
        Shape of the incoming tensor. It must have exactly four entries;
        the last entry is used as the channel count.

    Raises
    ------
    ValueError
        If ``inputs_shape`` does not have four entries, or if
        ``self.data_format`` is not ``'channels_last'``. ``ValueError`` is
        a subclass of ``Exception``, so callers that catch ``Exception``
        keep working.
    """
    if len(inputs_shape) != 4:
        raise ValueError("This SwitchNorm only supports 2D images.")
    if self.data_format != 'channels_last':
        raise ValueError("This SwitchNorm only supports channels_last.")

    # With channels_last enforced above, the channel count is the last entry.
    ch = inputs_shape[-1]

    # Per-channel weights, one value per channel.
    self.gamma = self._get_weights("gamma", shape=[ch], init=self.gamma_init)
    self.beta = self._get_weights("beta", shape=[ch], init=self.beta_init)

    # Two length-3 weight vectors, both initialised to 1.0.
    # NOTE(review): presumably one entry per normalisation statistic that
    # forward() blends -- confirm against forward().
    self.mean_weight_var = self._get_weights("mean_weight", shape=[3], init=tl.initializers.constant(1.0))
    self.var_weight_var = self._get_weights("var_weight", shape=[3], init=tl.initializers.constant(1.0))
| 862 | |
| 863 | def forward(self, inputs): |
| 864 |
nothing calls this directly
no test coverage detected