(l, n_out, stride, activation=tf.identity)
| 8 | |
| 9 | |
def resnet_shortcut(l, n_out, stride, activation=tf.identity):
    """Build the shortcut (skip) branch of a residual block.

    When the channel count of ``l`` already equals ``n_out`` the input is
    passed through untouched. Otherwise it is projected by a 1x1 ``Conv2D``
    layer named ``'convshortcut'``.

    Args:
        l: input tensor. Its channel axis is chosen from the ``data_format``
            registered for ``Conv2D`` in the current arg scope.
        n_out: number of channels the shortcut must produce.
        stride: stride of the projection convolution (used only when a
            projection is created).
        activation: activation passed to the projection convolution.

    Returns:
        ``l`` itself, or the output of the 1x1 projection convolution.
    """
    fmt = get_arg_scope()['Conv2D']['data_format']
    # Channel axis is 1 for channels-first layouts, 3 otherwise
    # (assumes a 4-D input such as NHWC - confirm against callers).
    channel_axis = 1 if fmt in ['NCHW', 'channels_first'] else 3
    n_in = l.get_shape().as_list()[channel_axis]

    if n_in == n_out:
        # Channels already agree: identity shortcut.
        # NOTE(review): `stride` is not applied on this path, so matching
        # channels with stride > 1 would not downsample - confirm callers
        # never rely on that.
        return l

    # Channel counts differ: project to n_out channels with a 1x1 convolution.
    return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation)
| 17 | |
| 18 | |
| 19 | def get_bn(zero_init=False): |
no test coverage detected