MCPcopy Index your code
hub / github.com/ZHKKKe/MODNet / LRBranch

Class LRBranch

torchscript/modnet_torchscript.py:92–117  ·  view source on GitHub ↗

Low Resolution Branch of MODNet

Source from the content-addressed store, hash-verified

#------------------------------------------------------------------------------
91
class LRBranch(nn.Module):
    """Low Resolution Branch of MODNet.

    Applies ``SEBlock`` to backbone feature 4, then performs two rounds of
    "bilinear 2x upsample -> 5x5 ``Conv2dIBNormRelu``". The result is returned
    together with backbone features 0 and 1, which are passed through untouched.

    Args:
        backbone: encoder module. It must expose ``enc_channels`` (indexed
            here at 2, 3 and 4), and its ``forward`` must return a sequence of
            feature maps indexed here at 0, 1 and 4.

    Note:
        ``conv_lr`` is constructed but never used by ``forward`` in this class.
        NOTE(review): it is presumably retained so that state dicts containing
        its weights still load with strict key matching. Confirm before
        removing it.
    """

    def __init__(self, backbone):
        super().__init__()

        channels = backbone.enc_channels
        ch2, ch3, ch4 = channels[2], channels[3], channels[4]

        # Registration order (backbone, se_block, conv_lr16x, conv_lr8x,
        # conv_lr) is kept as-is because it determines state_dict/parameter order.
        self.backbone = backbone
        self.se_block = SEBlock(ch4, ch4, reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(ch4, ch3, 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(ch3, ch2, 5, stride=1, padding=2)
        self.conv_lr = Conv2dIBNormRelu(
            ch2, 1,
            kernel_size=3, stride=2, padding=1,
            with_ibn=False, with_relu=False,
        )

    def forward(self, img):
        """Run the backbone and the low-resolution decoding path.

        Args:
            img: input passed directly to ``self.backbone.forward``.

        Returns:
            Tuple ``(lr8x, enc2x, enc4x)``:
            - ``lr8x`` is backbone feature 4 after ``SEBlock`` and two
              upsample/conv rounds.
            - ``enc2x`` and ``enc4x`` are backbone features 0 and 1, unchanged.
              They are presumably consumed by the high-resolution branch;
              verify against the caller.
        """
        # Called via .forward explicitly, so any module hooks on the backbone are bypassed.
        features = self.backbone.forward(img)
        enc2x = features[0]
        enc4x = features[1]

        attended = self.se_block(features[4])

        # Round 1: 2x bilinear upsample, then conv_lr16x.
        lr16x = self.conv_lr16x(
            F.interpolate(attended, scale_factor=2.0, mode='bilinear', align_corners=False)
        )
        # Round 2: 2x bilinear upsample, then conv_lr8x.
        lr8x = self.conv_lr8x(
            F.interpolate(lr16x, scale_factor=2.0, mode='bilinear', align_corners=False)
        )

        return lr8x, enc2x, enc4x
118
119
120class HRBranch(nn.Module):

Callers 1

__init__ · Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected