MCPcopy Index your code
hub / github.com/NVIDIA/semantic-segmentation / get_trunk

Function get_trunk

network/utils.py:102–141  ·  view source on GitHub ↗

Retrieve the network trunk and channel counts.

(trunk_name, output_stride=8)

Source from the content-addressed store, hash-verified

100
101
def get_trunk(trunk_name, output_stride=8):
    """
    Retrieve the network trunk and channel counts.

    Args:
        trunk_name: name of the backbone to build. Recognized values are
            'wrn38', 'xception71', 'seresnext-50', 'seresnext-101',
            'resnet-50', 'resnet-101' and 'hrnetv2'.
        output_stride: requested output stride. Only 8 is accepted. The
            value is forwarded to the xception71 and resnet/seresnext
            constructors; the wrn38 and hrnetv2 branches do not use it.

    Returns:
        A tuple ``(backbone, s2_ch, s4_ch, high_level_ch)``: the constructed
        backbone followed by three channel counts. Some branches set
        ``s2_ch`` and/or ``s4_ch`` to -1.
        NOTE(review): -1 looks like a "not available" placeholder for that
        trunk -- confirm against the callers.

    Raises:
        AssertionError: if ``output_stride`` is not 8.
        ValueError: if ``trunk_name`` is not one of the recognized names.
    """
    assert output_stride == 8, 'Only stride8 supported right now'

    if trunk_name == 'wrn38':
        #
        # FIXME: pass in output_stride once we support stride 16
        #
        backbone = wrn38(pretrained=True)
        s2_ch = 128
        s4_ch = 256
        high_level_ch = 4096
    elif trunk_name == 'xception71':
        backbone = xception71(output_stride=output_stride, BatchNorm=Norm2d,
                              pretrained=True)
        s2_ch = 64
        s4_ch = 128
        high_level_ch = 2048
    elif trunk_name in ('seresnext-50', 'seresnext-101'):
        backbone = get_resnet(trunk_name, output_stride=output_stride)
        s2_ch = 48
        s4_ch = -1
        high_level_ch = 2048
    elif trunk_name in ('resnet-50', 'resnet-101'):
        backbone = get_resnet(trunk_name, output_stride=output_stride)
        s2_ch = 256
        s4_ch = -1
        high_level_ch = 2048
    elif trunk_name == 'hrnetv2':
        backbone = hrnetv2.get_seg_model()
        # This trunk reports its own high-level channel count.
        high_level_ch = backbone.high_level_ch
        s2_ch = -1
        s4_ch = -1
    else:
        # Raising a bare string is itself a TypeError in Python 3 and hides
        # the message; raise a real exception that carries it.
        raise ValueError('unknown backbone {}'.format(trunk_name))

    logx.msg("Trunk: {}".format(trunk_name))
    return backbone, s2_ch, s4_ch, high_level_ch
142
143
144class ConvBnRelu(nn.Module):

Callers 15

__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90

Calls 3

wrn38Class · 0.90
xception71Class · 0.90
get_resnetClass · 0.85

Tested by

no test coverage detected