| 259 | return s.format(classname=self.__class__.__name__, **self.__dict__) |
| 260 | |
def build(self, inputs_shape):
    """Create the learnable slopes used below zero and above six.

    Two weights are requested through ``self._get_weights`` using the same
    shape and initializer (``self.a_init``), in this order:

    * ``alpha_low``: alpha for outputs lower than zero.
    * ``alpha_high``: alpha for outputs higher than 6.

    Each weight is also passed through ``tf.nn.sigmoid``. The squashed copy,
    which lies in (0, 1), is stored as ``alpha_low_constrained`` or
    ``alpha_high_constrained``.

    Args:
        inputs_shape: Shape of the layer input. Only its last entry is read,
            and only when ``self.channel_shared`` is falsy.
    """
    # A single shared slope, or one slope per entry of the last input
    # dimension (presumably the channel axis -- TODO confirm with callers).
    w_shape = (1, ) if self.channel_shared else (inputs_shape[-1], )

    # NOTE(review): if build() runs eagerly (TF2), the sigmoid results below
    # are computed once here. If forward() uses *_constrained directly, it
    # will not see later updates to alpha_low / alpha_high. Confirm against
    # forward().

    # Alpha for outputs lower than zero; create it first to keep weight order.
    self.alpha_low = self._get_weights(
        "alpha_low",
        shape=w_shape,
        init=self.a_init,
    )
    self.alpha_low_constrained = tf.nn.sigmoid(
        self.alpha_low,
        name="constraining_alpha_low_in_0_1",
    )

    # Alpha for outputs higher than 6.
    self.alpha_high = self._get_weights(
        "alpha_high",
        shape=w_shape,
        init=self.a_init,
    )
    self.alpha_high_constrained = tf.nn.sigmoid(
        self.alpha_high,
        name="constraining_alpha_high_in_0_1",
    )
| 274 | |
| 275 | # @tf.function |
| 276 | def forward(self, inputs): |