hub / github.com/tensorlayer/TensorLayer / _get_param_shape

Method _get_param_shape

tensorlayer/layers/normalization.py:460–473
(self, inputs_shape)

Source from the content-addressed store, hash-verified

458         return s.format(classname=self.__class__.__name__, **self.__dict__)
459
460     def _get_param_shape(self, inputs_shape):
461         if self.data_format == 'channels_last':
462             axis = len(inputs_shape) - 1
463         elif self.data_format == 'channels_first':
464             axis = 1
465         else:
466             raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
467
468         channels = inputs_shape[axis]
469         params_shape = [1] * len(inputs_shape)
470         params_shape[axis] = channels
471
472         axes = [i for i in range(len(inputs_shape)) if i != 0 and i != axis]
473         return params_shape, axes
474
475     def build(self, inputs_shape):
476         params_shape, self.axes = self._get_param_shape(inputs_shape)
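
To make the return values concrete, here is a minimal sketch that copies the logic of `_get_param_shape` into a standalone function. The name `get_param_shape`, the `data_format` keyword argument (the original reads `self.data_format`), and the sample shapes are assumptions made for illustration; they are not part of TensorLayer.

```python
def get_param_shape(inputs_shape, data_format="channels_last"):
    """Standalone copy of the method's logic (hypothetical helper, not a TensorLayer API)."""
    if data_format == "channels_last":
        axis = len(inputs_shape) - 1  # channel is the last dimension
    elif data_format == "channels_first":
        axis = 1  # channel follows the batch dimension
    else:
        raise ValueError(
            "data_format should be either %s or %s" % ("channels_last", "channels_first")
        )

    channels = inputs_shape[axis]
    # Parameter shape is 1 everywhere except the channel axis, so it can
    # broadcast against an input of the given rank.
    params_shape = [1] * len(inputs_shape)
    params_shape[axis] = channels

    # Every axis except batch (0) and channel.
    axes = [i for i in range(len(inputs_shape)) if i != 0 and i != axis]
    return params_shape, axes


# Sample 4-D shapes, assumed for illustration: batch 8, 32x32 spatial, 16 channels.
nhwc = get_param_shape([8, 32, 32, 16], "channels_last")
nchw = get_param_shape([8, 16, 32, 32], "channels_first")
print(nhwc)  # ([1, 1, 1, 16], [1, 2])
print(nchw)  # ([1, 16, 1, 1], [2, 3])
```

In both layouts `params_shape` has the same rank as the input with a single non-unit entry on the channel axis, and `axes` lists the remaining non-batch, non-channel dimensions, which `build` stores as `self.axes`.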

Callers 1

build · Method · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected