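Init. Args: path (str, optional): Path to saved model; if given, its weights are loaded after construction. Defaults to None. features (int, optional): Number of features. Defaults to 256. non_negative (bool, optional): If True, a final ReLU keeps the output non-negative. Defaults to True. Note: the encoder backbone is not configurable; it is fixed to resnext101_wsl.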
(self, path=None, features=256, non_negative=True)
| 14 | """ |
| 15 | |
| 16 | def __init__(self, path=None, features=256, non_negative=True): |
| 17 | """Init: build a resnext101_wsl encoder, four FeatureFusionBlock stages and an output head. |
| 18 | |
| 19 | Args: |
| 20 | path (str, optional): Path to saved model; if given, weights are loaded via self.load(path) after construction. Defaults to None. |
| 21 | features (int, optional): Number of features passed to the encoder and each fusion block. Defaults to 256. |
| 22 | non_negative (bool, optional): If True, the output head ends with a ReLU so predictions are >= 0; otherwise nn.Identity. Defaults to True. |
| 23 | """ |
| 24 | print("Loading weights: ", path)  # NOTE(review): printed even when path is None |
| 25 | |
| 26 | super(MidasNet, self).__init__() |
| 27 | |
| 28 | use_pretrained = False if path is None else True  # equivalent to `path is not None`; NOTE(review): pretrained encoder weights are requested only when a checkpoint path is given — confirm intent |
| 29 | |
| 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)  # backbone is hard-coded; __init__ has no backbone argument |
| 31 | |
| 32 | self.scratch.refinenet4 = FeatureFusionBlock(features)  # four fusion stages, each constructed with `features` |
| 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) |
| 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) |
| 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) |
| 36 | |
| 37 | self.scratch.output_conv = nn.Sequential(  # channels: features -> 128 -> Interpolate(x2, bilinear) -> 32 -> 1 |
| 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), |
| 39 | Interpolate(scale_factor=2, mode="bilinear"), |
| 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), |
| 41 | nn.ReLU(True), |
| 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), |
| 43 | nn.ReLU(True) if non_negative else nn.Identity(),  # clamp output to >= 0 only when non_negative is set |
| 44 | ) |
| 45 | |
| 46 | if path:  # truthiness check: an empty string also skips loading |
| 47 | self.load(path)  # load() is defined outside this view |
| 48 | |
| 49 | def forward(self, x): |
| 50 | """Forward pass. |
nothing calls this directly
no test coverage detected