
Function FixedUnPooling

tensorpack/models/pool.py, lines 89–138 (github.com/tensorpack/tensorpack)

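Unpool the input by taking its Kronecker product with a fixed matrix. Args: x (tf.Tensor): a 4D image tensor; shape: int or (h, w) tuple; unpool_mat: a tf.Tensor or np.ndarray 2D matrix of size shape (if None, a matrix with 1 at the top-left corner and 0 elsewhere is used); data_format: NHWC ('channels_last', the default) or NCHW. Returns: a 4D tensor whose height and width are multiplied by shape[0] and shape[1].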

(x, shape, unpool_mat=None, data_format='channels_last')
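To make the Kronecker-product semantics concrete, here is a minimal NumPy sketch. It is not tensorpack code: `kron_unpool` and `default_unpool_mat` are hypothetical helper names, and the sketch assumes plain NumPy arrays in NHWC or NCHW layout. It applies `np.kron` to every (image, channel) slice and uses the same default matrix as the source (zeros with a 1 in the top-left corner).

```python
import numpy as np


def kron_unpool(x, unpool_mat, data_format="NHWC"):
    """Hypothetical NumPy reference for fixed-matrix unpooling.

    Every spatial slice (one image, one channel) is replaced by
    np.kron(slice, unpool_mat), so an (h, w) map becomes (h*sh, w*sw).
    """
    if data_format == "NHWC":
        x = np.transpose(x, (0, 3, 1, 2))  # b,h,w,c -> b,c,h,w
    b, c, h, w = x.shape
    sh, sw = unpool_mat.shape
    out = np.empty((b, c, h * sh, w * sw), dtype=np.result_type(x, unpool_mat))
    for i in range(b):
        for j in range(c):
            out[i, j] = np.kron(x[i, j], unpool_mat)
    if data_format == "NHWC":
        out = np.transpose(out, (0, 2, 3, 1))  # back to b,h,w,c
    return out


def default_unpool_mat(shape):
    """Mirrors the default in the source: zeros with a 1 in the top-left corner."""
    mat = np.zeros(shape, dtype="float32")
    mat[0][0] = 1
    return mat


x = np.arange(4, dtype=np.float32).reshape(1, 2, 2, 1)  # one 2x2 image, 1 channel
y = kron_unpool(x, default_unpool_mat((2, 2)))
print(y[0, :, :, 0])
# [[0. 0. 1. 0.]
#  [0. 0. 0. 0.]
#  [2. 0. 3. 0.]
#  [0. 0. 0. 0.]]

# With the default 2x2 matrix this is plain zero-filling: each input value
# lands at the top-left of its 2x2 output block, everything else is 0.
zero_filled = np.zeros((1, 4, 4, 1), dtype=np.float32)
zero_filled[:, ::2, ::2, :] = x
assert np.array_equal(y, zero_filled)
```

Swapping in `np.ones((2, 2))` as the matrix turns the same operation into nearest-neighbour upsampling, since every input value is then copied into its whole 2x2 block.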


@layer_register(log_shape=True)
def FixedUnPooling(x, shape, unpool_mat=None, data_format='channels_last'):
    """
    Unpool the input with a fixed matrix to perform kronecker product with.

    Args:
        x (tf.Tensor): a 4D image tensor
        shape: int or (h, w) tuple
        unpool_mat: a tf.Tensor or np.ndarray 2D matrix with size=shape.
            If is None, will use a matrix with 1 at top-left corner.

    Returns:
        tf.Tensor: a 4D image tensor.
    """
    data_format = get_data_format(data_format, keras_mode=False)
    shape = shape2d(shape)

    output_shape = StaticDynamicShape(x)
    output_shape.apply(1 if data_format == 'NHWC' else 2, lambda x: x * shape[0])
    output_shape.apply(2 if data_format == 'NHWC' else 3, lambda x: x * shape[1])

    # a faster implementation for this special case
    if shape[0] == 2 and shape[1] == 2 and unpool_mat is None and data_format == 'NHWC':
        ret = UnPooling2x2ZeroFilled(x)
    else:
        # check unpool_mat
        if unpool_mat is None:
            mat = np.zeros(shape, dtype='float32')
            mat[0][0] = 1
            unpool_mat = tf.constant(mat, name='unpool_mat')
        elif isinstance(unpool_mat, np.ndarray):
            unpool_mat = tf.constant(unpool_mat, name='unpool_mat')
        assert unpool_mat.shape.as_list() == list(shape)

        if data_format == 'NHWC':
            x = tf.transpose(x, [0, 3, 1, 2])
        # perform a tensor-matrix kronecker product
        x = tf.expand_dims(x, -1)  # bchwx1
        mat = tf.expand_dims(unpool_mat, 0)  # 1xshxsw
        ret = tf.tensordot(x, mat, axes=1)  # bxcxhxwxshxsw

        if data_format == 'NHWC':
            ret = tf.transpose(ret, [0, 2, 4, 3, 5, 1])
        else:
            ret = tf.transpose(ret, [0, 1, 2, 4, 3, 5])

        shape3_dyn = [output_shape.get_dynamic(k) for k in range(1, 4)]
        ret = tf.reshape(ret, tf.stack([-1] + shape3_dyn))

    ret.set_shape(tf.TensorShape(output_shape.get_static()))
    return ret
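The general branch never calls a Kronecker-product op directly. It builds an outer product (a `tensordot` over a size-1 axis), transposes so that each input spatial axis sits next to the matching matrix axis, and reshapes. The sketch below transcribes those steps into NumPy, with `np.expand_dims`, `np.tensordot`, `np.transpose` and `reshape` standing in for their `tf` counterparts; `unpool_via_tensordot` is a hypothetical name, and the sketch assumes static shapes, so it skips the dynamic-shape bookkeeping. It then checks the result against `np.kron`.

```python
import numpy as np


def unpool_via_tensordot(x, unpool_mat, data_format="NHWC"):
    """NumPy transcription of the general (non-fast-path) branch."""
    if data_format == "NHWC":
        x = np.transpose(x, (0, 3, 1, 2))      # b,h,w,c -> b,c,h,w
    x = np.expand_dims(x, -1)                  # b,c,h,w,1
    mat = np.expand_dims(unpool_mat, 0)        # 1,sh,sw
    ret = np.tensordot(x, mat, axes=1)         # b,c,h,w,sh,sw (outer product)
    b, c, h, w, sh, sw = ret.shape
    if data_format == "NHWC":
        ret = np.transpose(ret, (0, 2, 4, 3, 5, 1))  # b,h,sh,w,sw,c
        return ret.reshape(b, h * sh, w * sw, c)
    ret = np.transpose(ret, (0, 1, 2, 4, 3, 5))      # b,c,h,sh,w,sw
    return ret.reshape(b, c, h * sh, w * sw)


rng = np.random.default_rng(0)
mat = rng.standard_normal((2, 3))          # non-square, to catch axis mix-ups
x_nhwc = rng.standard_normal((2, 4, 5, 3))
y = unpool_via_tensordot(x_nhwc, mat, "NHWC")
assert y.shape == (2, 8, 15, 3)
# Each channel of each image is exactly the Kronecker product with mat.
assert np.allclose(y[1, :, :, 2], np.kron(x_nhwc[1, :, :, 2], mat))
```

The transpose is the step that matters: placing `sh` directly after `h` and `sw` directly after `w` lets the row-major reshape merge `(h, sh)` into `h*sh` and `(w, sw)` into `w*sw`, which is exactly the block layout of a Kronecker product.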

Callers (2)

upsample2x (function) · 0.90
test_FixedUnPooling (method) · 0.85

Calls (7)

apply (method) · 0.95
get_dynamic (method) · 0.95
get_static (method) · 0.95
get_data_format (function) · 0.85
shape2d (function) · 0.85
StaticDynamicShape (class) · 0.85
UnPooling2x2ZeroFilled (function) · 0.85

Tested by (1)

test_FixedUnPooling (method) · 0.68