hub / github.com/tensorpack/tensorpack / GlobalAvgPooling

Function GlobalAvgPooling

tensorpack/models/pool.py:56–70

Global average pooling as in the paper `Network In Network <http://arxiv.org/abs/1312.4400>`_.

Args: x (tf.Tensor): a 4D tensor.

Returns: tf.Tensor: a NC tensor named ``output``.

(x, data_format='channels_last')

Source from the content-addressed store, hash-verified

@layer_register(log_shape=True)
def GlobalAvgPooling(x, data_format='channels_last'):
    """
    Global average pooling as in the paper `Network In Network
    <http://arxiv.org/abs/1312.4400>`_.

    Args:
        x (tf.Tensor): a 4D tensor.

    Returns:
        tf.Tensor: a NC tensor named ``output``.
    """
    assert x.shape.ndims == 4
    data_format = get_data_format(data_format)
    axis = [1, 2] if data_format == 'channels_last' else [2, 3]
    return tf.reduce_mean(x, axis, name='output')
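The reduction in the listing above can be illustrated outside TensorFlow. The following is a minimal NumPy sketch, not the tensorpack implementation: `global_avg_pool` is a hypothetical stand-in name, and it assumes `data_format` has already been normalized to `'channels_last'` or `'channels_first'` (the normalization done by `get_data_format` is not reproduced here). It averages over the two spatial axes, so a 4D input collapses to a 2D (batch, channels) array.

```python
import numpy as np


def global_avg_pool(x, data_format="channels_last"):
    """Hypothetical NumPy stand-in for the reduction in GlobalAvgPooling."""
    x = np.asarray(x)
    if x.ndim != 4:
        raise ValueError("expected a 4D array, got %d dimensions" % x.ndim)
    if data_format not in ("channels_last", "channels_first"):
        raise ValueError("unsupported data_format: %r" % (data_format,))
    # Spatial axes are (H, W): 1 and 2 for NHWC, 2 and 3 for NCHW.
    axis = (1, 2) if data_format == "channels_last" else (2, 3)
    return x.mean(axis=axis)


# A batch of 2 feature maps, 3x3 spatial, 4 channels, laid out as NHWC.
nhwc = np.arange(2 * 3 * 3 * 4, dtype=np.float64).reshape(2, 3, 3, 4)
pooled_last = global_avg_pool(nhwc)

# The same data laid out as NCHW should pool to the same (N, C) result.
nchw = nhwc.transpose(0, 3, 1, 2)
pooled_first = global_avg_pool(nchw, data_format="channels_first")

print(pooled_last.shape)
print(np.allclose(pooled_last, pooled_first))
```

Both layouts yield the same (N, C) array because only the positions of the spatial axes differ. The sketch has no counterpart to the `name='output'` argument, which only names the resulting tensor in the TensorFlow graph.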

Callers 10

se_bottleneck · Function · 0.90
resnet_backbone · Function · 0.90
se_bottleneck · Function · 0.90
resnet_backbone · Function · 0.90
roi_heads · Method · 0.90
build_graph · Method · 0.85
build_graph · Method · 0.85
get_logits · Method · 0.85
build_graph · Method · 0.85
get_logits · Method · 0.85

Calls 1

get_data_format · Function · 0.85

Tested by

no test coverage detected