Returns: `label_logits` with shape fH x fW x NA, and `box_logits` with shape fH x fW x NA x 4.
Signature: `rpn_head(featuremap, channel, num_anchors)`
@layer_register(log_shape=True)
@auto_reuse_variable_scope
def rpn_head(featuremap, channel, num_anchors):
    """
    RPN head: a shared 3x3 conv followed by two 1x1 convs that produce
    per-anchor classification logits and per-anchor box logits.

    Args:
        featuremap: input feature map in NCHW layout (the convs use
            data_format='channels_first'). Its batch dimension must be 1,
            because that axis is squeezed out of the outputs.
        channel: filter count of the shared 3x3 'conv0' layer.
        num_anchors: NA, the number of anchors per feature-map location.

    Returns:
        label_logits: fHxfWxNA
        box_logits: fHxfWxNAx4
    """
    init = tf.random_normal_initializer(stddev=0.01)
    with argscope(Conv2D, data_format='channels_first', kernel_initializer=init):
        shared = Conv2D('conv0', featuremap, channel, 3, activation=tf.nn.relu)
        cls_nchw = Conv2D('class', shared, num_anchors, 1)     # 1 x NA x fH x fW
        box_nchw = Conv2D('box', shared, 4 * num_anchors, 1)   # 1 x (NAx4) x fH x fW

    # NCHW -> NHWC, then drop the singleton batch axis
    # (tf.squeeze on axis 0 fails if the batch size is not 1).
    label_logits = tf.squeeze(tf.transpose(cls_nchw, [0, 2, 3, 1]), 0)  # fHxfWxNA

    # Spatial size is only known at run time, so read it from the dynamic shape.
    box_dims = tf.shape(box_nchw)
    fh, fw = box_dims[2], box_dims[3]
    box_nhwc = tf.transpose(box_nchw, [0, 2, 3, 1])  # 1 x fH x fW x (NAx4)
    box_logits = tf.reshape(box_nhwc, tf.stack([fh, fw, num_anchors, 4]))  # fHxfWxNAx4
    return label_logits, box_logits
| 39 | |
| 40 | |
| 41 | @under_name_scope() |