Retrieve the network trunk and channel counts.
get_trunk(trunk_name, output_stride=8)
| 100 | |
| 101 | |
def get_trunk(trunk_name, output_stride=8):
    """
    Retrieve the network trunk and channel counts.

    Args:
        trunk_name: Name of the backbone to build. One of 'wrn38',
            'xception71', 'seresnext-50', 'seresnext-101', 'resnet-50',
            'resnet-101' or 'hrnetv2'.
        output_stride: Requested output stride of the trunk. Only 8 is
            supported right now.

    Returns:
        Tuple ``(backbone, s2_ch, s4_ch, high_level_ch)``: the trunk module
        and the channel counts this function associates with it. ``-1`` is
        returned for counts that are not set for a given trunk (presumably
        callers treat -1 as "not available" -- confirm against callers).

    Raises:
        ValueError: if ``output_stride`` is not 8, or if ``trunk_name`` is
            not one of the recognised backbones.
    """
    # Explicit check rather than ``assert`` so the validation is not
    # stripped when Python runs with -O.
    if output_stride != 8:
        raise ValueError('Only stride8 supported right now')

    if trunk_name == 'wrn38':
        #
        # FIXME: pass in output_stride once we support stride 16
        #
        backbone = wrn38(pretrained=True)
        s2_ch = 128
        s4_ch = 256
        high_level_ch = 4096
    elif trunk_name == 'xception71':
        backbone = xception71(output_stride=output_stride, BatchNorm=Norm2d,
                              pretrained=True)
        s2_ch = 64
        s4_ch = 128
        high_level_ch = 2048
    elif trunk_name in ('seresnext-50', 'seresnext-101'):
        backbone = get_resnet(trunk_name, output_stride=output_stride)
        s2_ch = 48
        s4_ch = -1
        high_level_ch = 2048
    elif trunk_name in ('resnet-50', 'resnet-101'):
        backbone = get_resnet(trunk_name, output_stride=output_stride)
        s2_ch = 256
        s4_ch = -1
        high_level_ch = 2048
    elif trunk_name == 'hrnetv2':
        backbone = hrnetv2.get_seg_model()
        # hrnetv2 reports its own high-level channel count on the module.
        high_level_ch = backbone.high_level_ch
        s2_ch = -1
        s4_ch = -1
    else:
        # Raising a bare string is itself a TypeError in Python 3 and would
        # hide this message; raise a real exception instead.
        raise ValueError('unknown backbone {}'.format(trunk_name))

    logx.msg("Trunk: {}".format(trunk_name))
    return backbone, s2_ch, s4_ch, high_level_ch
| 142 | |
| 143 | |
| 144 | class ConvBnRelu(nn.Module): |