| 318 | Hourglass Encoder |
| 319 | """ |
def __init__(self,
             block_expansion,
             in_features,
             num_blocks=3,
             max_features=256,
             mobile_net=False):
    """Build the stack of down blocks for the hourglass encoder.

    Block ``i`` (0-based) maps ``in_features`` channels (for ``i == 0``)
    or ``min(max_features, block_expansion * 2**i)`` channels (otherwise)
    to ``min(max_features, block_expansion * 2**(i + 1))`` channels.
    Every block is built with ``kernel_size=3`` and ``padding=1``.

    Args:
        block_expansion: base channel multiplier; the width doubles with
            each successive block until capped by ``max_features``.
        in_features: channel count fed to the first block.
        num_blocks: how many down blocks to stack.
        max_features: upper bound applied to every computed width.
        mobile_net: if truthy, build ``MobileDownBlock2d`` blocks;
            otherwise build ``DownBlock2d`` blocks.
    """
    super(Encoder, self).__init__()

    # Both block variants take the same constructor arguments, so the
    # class is chosen once instead of duplicating the call per branch.
    block_cls = MobileDownBlock2d if mobile_net else DownBlock2d

    def _width(level):
        # Channel width at the given depth, capped at max_features.
        return min(max_features, block_expansion * (2**level))

    blocks = []
    for level in range(num_blocks):
        c_in = in_features if level == 0 else _width(level)
        c_out = _width(level + 1)
        blocks.append(block_cls(c_in, c_out, kernel_size=3, padding=1))
    self.down_blocks = nn.LayerList(blocks)
| 347 | |
| 348 | def forward(self, x): |
| 349 | outs = [x] |