| 458 | return s.format(classname=self.__class__.__name__, **self.__dict__) |
| 459 | |
def _get_param_shape(self, inputs_shape):
    """Derive the per-channel parameter shape and the non-channel axes.

    The channel axis is the last axis when ``self.data_format`` is
    ``'channels_last'`` and axis 1 when it is ``'channels_first'``.

    Args:
        inputs_shape: Indexable sequence giving the size of each input axis.

    Returns:
        A ``(params_shape, axes)`` tuple. ``params_shape`` is a list as long
        as ``inputs_shape`` holding 1 everywhere except at the channel axis,
        where it holds the channel count. ``axes`` lists every axis index
        other than 0 and the channel axis (presumably the axes statistics
        are reduced over -- confirm against the callers).

    Raises:
        ValueError: If ``self.data_format`` is neither ``'channels_last'``
            nor ``'channels_first'``.
    """
    layout = self.data_format
    # Reject unknown layouts before touching inputs_shape at all.
    if layout not in ('channels_last', 'channels_first'):
        raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

    rank = len(inputs_shape)
    channel_axis = rank - 1 if layout == 'channels_last' else 1

    # Index explicitly so an out-of-range channel axis still raises here.
    num_channels = inputs_shape[channel_axis]
    params_shape = [1] * rank
    params_shape[channel_axis] = num_channels

    axes = [dim for dim in range(rank) if dim not in (0, channel_axis)]
    return params_shape, axes
| 474 | |
| 475 | def build(self, inputs_shape): |
| 476 | params_shape, self.axes = self._get_param_shape(inputs_shape) |