MCPcopy
hub / github.com/alibaba/EasyCV / PNet

Class PNet

easycv/thirdparty/mtcnn/get_nets.py:52–101  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

50
51
class PNet(nn.Module):
    """MTCNN first-stage network (P-Net).

    A small fully convolutional network: a three-layer conv/PReLU trunk
    (``self.features``) followed by two 1x1 convolution heads. ``conv4_1``
    produces 2 channels, which ``forward`` passes through a softmax.
    ``conv4_2`` produces 4 channels, which are returned as-is.

    There are no fully connected layers, so any spatial input size that
    survives the convolutions is accepted. See the size notes in
    ``__init__``.

    Every parameter is overwritten at construction time with pretrained
    values loaded from ``weights/pnet.npy``.
    """

    def __init__(self, dir_path=None):
        """Build the layers and load the pretrained weights.

        Args:
            dir_path: directory handed to ``get_url_weights`` together with
                ``"weights/pnet.npy"``. Defaults to the directory containing
                this module.

        Raises:
            KeyError: if the loaded weights have no entry for one of this
                module's parameter names.
            ValueError: if a loaded weight array's shape differs from the
                shape of the parameter it is meant to initialise.
        """
        super(PNet, self).__init__()

        # suppose we have input with size HxW, then
        # after first layer: H - 2,
        # after pool: ceil((H - 2)/2),
        # after second conv: ceil((H - 2)/2) - 2,
        # after last conv: ceil((H - 2)/2) - 4,
        # and the same for W.
        # Hence H and W must be at least 11, otherwise the last conv has no
        # room for its 3x3 kernel. A 12x12 input yields a 1x1 output.
        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(3, 10, 3, 1)),
                    ("prelu1", nn.PReLU(10)),
                    ("pool1", nn.MaxPool2d(2, 2, ceil_mode=True)),
                    ("conv2", nn.Conv2d(10, 16, 3, 1)),
                    ("prelu2", nn.PReLU(16)),
                    ("conv3", nn.Conv2d(16, 32, 3, 1)),
                    ("prelu3", nn.PReLU(32)),
                ]
            )
        )

        # Two 1x1 heads on the 32-channel trunk output: 2 channels and 4 channels.
        self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
        self.conv4_2 = nn.Conv2d(32, 4, 1, 1)

        if dir_path is None:
            dir_path = path.dirname(__file__)

        # NOTE(review): get_url_weights is defined outside this class. It is
        # used here as returning a mapping keyed by parameter name (e.g.
        # "features.conv1.weight", "conv4_1.bias") whose values are
        # array-likes of pretrained weights. Confirm against its definition.
        weights = get_url_weights("weights/pnet.npy", dir_path)

        # Copy each pretrained array into the existing parameter in place
        # instead of rebinding `p.data`. This keeps every nn.Parameter's
        # dtype and device, and it fails loudly on a shape mismatch instead
        # of silently swapping in a wrongly shaped tensor. no_grad() is
        # required because in-place writes to a leaf that requires grad are
        # otherwise rejected by autograd.
        with torch.no_grad():
            for name, param in self.named_parameters():
                value = torch.as_tensor(
                    weights[name], dtype=param.dtype, device=param.device
                )
                if value.shape != param.shape:
                    raise ValueError(
                        f"pnet weight {name!r} has shape {tuple(value.shape)}, "
                        f"expected {tuple(param.shape)}"
                    )
                param.copy_(value)

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4, h', w'], the raw
                output of ``conv4_2``.
            a: a float tensor with shape [batch_size, 2, h', w'], the output
                of ``conv4_1`` after a softmax over the channel dimension
                (the 2 values at each location sum to 1).
            Here h' = ceil((h - 2) / 2) - 4, and likewise for w'.
            Note the return order is (b, a).
        """
        x = self.features(x)
        a = self.conv4_1(x)
        b = self.conv4_2(x)
        a = F.softmax(a, dim=1)
        return b, a
102
103
104class RNet(nn.Module):

Callers 1

__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected