MCP · copy
hub / github.com/thygate/stable-diffusion-webui-depthmap-script / MidasNet

Class MidasNet

dmidas/midas_net.py:12–76  ·  view source on GitHub ↗

Network for monocular depth estimation.

Source from the content-addressed store, hash-verified

10
11
12class MidasNet(BaseModel):
13 """Network for monocular depth estimation.
14 """
15
def __init__(self, path=None, features=256, non_negative=True):
    """Init.

    Builds the encoder with ``_make_encoder`` using the hard-coded
    ``"resnext101_wsl"`` backbone (the backbone is not configurable). It then
    adds four ``FeatureFusionBlock`` refinement stages and an output head that
    projects to a single channel. If a checkpoint path is given, the weights
    are loaded at the end.

    Args:
        path (str, optional): Path to saved model weights. When truthy, the
            weights are loaded via ``self.load(path)`` after the network is
            built. Defaults to None.
        features (int, optional): Number of feature channels. It is passed to
            ``_make_encoder`` and to every ``FeatureFusionBlock``, and is also
            the input width of the output head. Defaults to 256.
        non_negative (bool, optional): If True, the output head ends with a
            ``ReLU``, so predictions are clamped to be non-negative. Otherwise
            the head ends with ``nn.Identity``. Defaults to True.
    """
    print("Loading weights: ", path)

    super().__init__()

    # Request the encoder's pretrained backbone weights only when a checkpoint
    # path was supplied (same rule as the original `False if path is None else True`).
    use_pretrained = path is not None

    self.pretrained, self.scratch = _make_encoder(
        backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
    )

    # Refinement stages; `forward` applies refinenet4 to the deepest
    # encoder level first.
    self.scratch.refinenet4 = FeatureFusionBlock(features)
    self.scratch.refinenet3 = FeatureFusionBlock(features)
    self.scratch.refinenet2 = FeatureFusionBlock(features)
    self.scratch.refinenet1 = FeatureFusionBlock(features)

    # Output head: features -> 128 channels, 2x bilinear interpolation,
    # then 128 -> 32 -> 1 channel.
    self.scratch.output_conv = nn.Sequential(
        nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
        Interpolate(scale_factor=2, mode="bilinear"),
        nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(True),
        nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        # Final activation: clamp to >= 0 only when non_negative is requested.
        nn.ReLU(True) if non_negative else nn.Identity(),
    )

    # NOTE(review): `use_pretrained` checks `path is not None`, but loading
    # checks truthiness. So path="" requests pretrained backbone weights yet
    # loads no checkpoint. Confirm whether that asymmetry is intended.
    if path:
        self.load(path)
48
49 def forward(self, x):
50 """Forward pass.
51
52 Args:
53 x (tensor): input data (image)
54
55 Returns:
56 tensor: depth
57 """
58
59 layer_1 = self.pretrained.layer1(x)
60 layer_2 = self.pretrained.layer2(layer_1)
61 layer_3 = self.pretrained.layer3(layer_2)
62 layer_4 = self.pretrained.layer4(layer_3)
63
64 layer_1_rn = self.scratch.layer1_rn(layer_1)
65 layer_2_rn = self.scratch.layer2_rn(layer_2)
66 layer_3_rn = self.scratch.layer3_rn(layer_3)
67 layer_4_rn = self.scratch.layer4_rn(layer_4)
68
69 path_4 = self.scratch.refinenet4(layer_4_rn)

Callers 2

load_model — Function · 0.90
load_models — Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected