class GlobalMeanPool2d(Layer):
    """The :class:`GlobalMeanPool2d` class is a 2D Global Mean Pooling layer.

    Parameters
    ----------
    data_format : str
        One of 'channels_last' (default, [batch, height, width, channel]) or
        'channels_first' ([batch, channel, height, width]). The ordering of the
        dimensions in the inputs.
    name : None or str
        A unique layer name.

    Examples
    --------
    With TensorLayer

    >>> net = tl.layers.Input([None, 100, 100, 30], name='input')
    >>> net = tl.layers.GlobalMeanPool2d()(net)
    >>> # output shape: [None, 30]

    """

    def __init__(
        self,
        data_format='channels_last',
        name=None  # 'globalmeanpool2d'
    ):
        super().__init__(name)

        self.data_format = data_format

        # The layer has no weights, so it can be built immediately.
        self.build()
        self._built = True

        logging.info("GlobalMeanPool2d %s" % self.name)

    def __repr__(self):
        s = '{classname}('
        if self.name is not None:
            s += 'name=\'{name}\''
        s += ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def build(self, inputs_shape=None):
        pass

    def forward(self, inputs):
        # Average over the two spatial axes (height and width).
        if self.data_format == 'channels_last':
            outputs = tf.reduce_mean(input_tensor=inputs, axis=[1, 2], name=self.name)
        elif self.data_format == 'channels_first':
            outputs = tf.reduce_mean(input_tensor=inputs, axis=[2, 3], name=self.name)
        else:
            raise ValueError(
                "`data_format` should have one of the following values: [`channels_last`, `channels_first`]"
            )
        return outputs
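
The reduction in `forward` can be checked by hand on a tiny input. The following is a minimal pure-Python sketch, not part of TensorLayer: `global_mean_pool_2d` is a hypothetical helper that works on nested lists instead of tensors and is assumed to mirror the `tf.reduce_mean` call over axes `[1, 2]` (`channels_last`) or `[2, 3]` (`channels_first`).

```python
def global_mean_pool_2d(inputs, data_format="channels_last"):
    """Average a 4-D nested list over its two spatial axes.

    Hypothetical stand-in for the tf.reduce_mean call in forward();
    returns a nested list of shape [batch, channel].
    """
    if data_format == "channels_last":
        # inputs[b][h][w][c] -> outputs[b][c]
        outputs = []
        for sample in inputs:
            height, width = len(sample), len(sample[0])
            channels = len(sample[0][0])
            outputs.append([
                sum(sample[h][w][c] for h in range(height) for w in range(width))
                / (height * width)
                for c in range(channels)
            ])
        return outputs
    if data_format == "channels_first":
        # inputs[b][c][h][w] -> outputs[b][c]
        return [
            [
                sum(sum(row) for row in plane) / (len(plane) * len(plane[0]))
                for plane in sample
            ]
            for sample in inputs
        ]
    raise ValueError(
        "`data_format` should be one of: `channels_last`, `channels_first`"
    )


# One sample, 2x2 spatial grid, 3 channels, laid out as [batch, height, width, channel].
nhwc = [[[[1, 10, 100], [2, 20, 200]],
         [[3, 30, 300], [4, 40, 400]]]]

# The same values laid out as [batch, channel, height, width].
nchw = [[[[1, 2], [3, 4]],
         [[10, 20], [30, 40]],
         [[100, 200], [300, 400]]]]

print(global_mean_pool_2d(nhwc))                    # [[2.5, 25.0, 250.0]]
print(global_mean_pool_2d(nchw, "channels_first"))  # [[2.5, 25.0, 250.0]]
```

Both layouts collapse to one mean per channel, which is why the docstring example maps an input of shape `[None, 100, 100, 30]` to an output of shape `[None, 30]`.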


class GlobalMaxPool3d(Layer):