Network for monocular depth estimation.
| 10 | |
| 11 | |
| 12 | class MidasNet(BaseModel): |
| 13 | """Network for monocular depth estimation. |
| 14 | """ |
| 15 | |
def __init__(self, path=None, features=256, non_negative=True):
    """Init.

    Builds the encoder, four feature-fusion refinement stages and the
    output head. If ``path`` is given, it then loads saved weights.

    Args:
        path (str, optional): Path to saved model. Defaults to None.
            When truthy, weights are loaded via ``self.load(path)`` after
            all submodules have been created.
        features (int, optional): Number of features. Passed to the
            encoder factory and to each ``FeatureFusionBlock``. Defaults
            to 256.
        non_negative (bool, optional): If True, the output head ends in a
            ReLU, so its output is clamped to be >= 0. Otherwise it ends
            in ``nn.Identity``. Defaults to True.

    Note:
        The encoder backbone is fixed to ``"resnext101_wsl"``. It is not
        a constructor argument.
    """
    print("Loading weights: ", path)

    super(MidasNet, self).__init__()

    # NOTE(review): the flag is True only when a checkpoint path is
    # supplied, and that checkpoint is loaded over the whole network at the
    # end of __init__. What ``_make_encoder`` does with ``use_pretrained``
    # is not visible here -- confirm this coupling is intended.
    use_pretrained = path is not None

    self.pretrained, self.scratch = _make_encoder(
        backbone="resnext101_wsl",
        features=features,
        use_pretrained=use_pretrained,
    )

    # One fusion block per encoder level. In forward(), refinenet4
    # consumes the projected layer4 features.
    self.scratch.refinenet4 = FeatureFusionBlock(features)
    self.scratch.refinenet3 = FeatureFusionBlock(features)
    self.scratch.refinenet2 = FeatureFusionBlock(features)
    self.scratch.refinenet1 = FeatureFusionBlock(features)

    self.scratch.output_conv = nn.Sequential(
        nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
        # Interpolate is defined elsewhere; presumably 2x bilinear upsampling.
        Interpolate(scale_factor=2, mode="bilinear"),
        nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(True),
        # Single-channel output (presumably the depth map forward returns).
        nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        # Clamp to >= 0 when requested; otherwise pass values through.
        nn.ReLU(True) if non_negative else nn.Identity(),
    )

    # Truthiness check: an empty-string path is skipped, as is None.
    if path:
        self.load(path)
| 48 | |
| 49 | def forward(self, x): |
| 50 | """Forward pass. |
| 51 | |
| 52 | Args: |
| 53 | x (tensor): input data (image) |
| 54 | |
| 55 | Returns: |
| 56 | tensor: depth |
| 57 | """ |
| 58 | |
| 59 | layer_1 = self.pretrained.layer1(x) |
| 60 | layer_2 = self.pretrained.layer2(layer_1) |
| 61 | layer_3 = self.pretrained.layer3(layer_2) |
| 62 | layer_4 = self.pretrained.layer4(layer_3) |
| 63 | |
| 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) |
| 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) |
| 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) |
| 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) |
| 68 | |
| 69 | path_4 = self.scratch.refinenet4(layer_4_rn) |
no outgoing calls
no test coverage detected