(modules)
| 6 | from packaging import version |
| 7 | |
def init_weights(modules):
    """Initialize Conv2d, BatchNorm2d and Linear layers in place.

    Only the modules yielded by ``modules`` are visited; this function does
    not recurse into children itself (presumably callers pass something like
    ``model.modules()`` — confirm at call sites).

    Per layer type:
      * ``nn.Conv2d``: Xavier-uniform weight, zero bias (if present).
      * ``nn.BatchNorm2d``: weight = 1, bias = 0 (only when the layer has
        affine parameters).
      * ``nn.Linear``: weight ~ N(0, 0.01), zero bias (if present).
    Any other module type is left untouched.

    Args:
        modules: Iterable of ``nn.Module`` instances to initialize.
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False, BatchNorm has no weight/bias parameters
            # (both are None), so there is nothing to initialize.
            if m.weight is not None:
                init.ones_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, 0, 0.01)
            # nn.Linear(..., bias=False) sets bias to None.
            if m.bias is not None:
                init.zeros_(m.bias)
| 20 | |
| 21 | |
| 22 | class vgg16_bn(torch.nn.Module): |