(self, inputs_shape)
| 171 | return s.format(classname=self.__class__.__name__, **self.__dict__) |
| 172 | |
def build(self, inputs_shape):
    """Create the trainable ``alpha`` weight and its sigmoid-squashed view.

    Args:
        inputs_shape: Shape of the incoming tensor. Only its last entry is
            read, and only when ``self.channel_shared`` is falsy.

    Side effects:
        Sets ``self.alpha_var`` (obtained from ``self._get_weights``) and
        ``self.alpha_var_constrained`` (``tf.nn.sigmoid`` of that weight).
    """
    # Either a single slope shared across all channels, or one slope per
    # channel sized by the trailing dimension of the input shape.
    weight_shape = (1, ) if self.channel_shared else (inputs_shape[-1], )
    self.alpha_var = self._get_weights("alpha", shape=weight_shape, init=self.a_init)

    # Squash alpha into the open interval (0, 1) with a sigmoid.
    # NOTE(review): this is evaluated once, here in build(). If the layer
    # runs eagerly, later optimizer updates to ``alpha_var`` would not be
    # reflected in this cached value unless forward() recomputes the
    # sigmoid per call — confirm against forward().
    self.alpha_var_constrained = tf.nn.sigmoid(self.alpha_var, name="constraining_alpha_var_in_0_1")
| 180 | |
| 181 | # @tf.function |
| 182 | def forward(self, inputs): |