| 75 | class WideResNet(nn.Module): |
| 76 | |
def __init__(self,
             block,
             layers,
             sample_size,
             sample_duration,
             k=1,
             shortcut_type='B',
             num_classes=400):
    """Assemble the stem, four residual stages, pooling and classifier.

    Args:
        block: Residual block class. It is forwarded to ``_make_layer``
            and must expose an ``expansion`` attribute, which is used
            to size the classifier input.
        layers: Indexable with (at least) four entries; ``layers[i]``
            is forwarded to ``_make_layer`` as its ``blocks`` argument
            for stage ``i + 1``.
        sample_size: Input extent used for the last two pooling-window
            dimensions, ``ceil(sample_size / 32)`` (presumably the
            frame height/width -- confirm against the caller).
        sample_duration: Input extent used for the first
            pooling-window dimension, ``ceil(sample_duration / 16)``
            (presumably the number of frames -- confirm against the
            caller).
        k: Widening factor; the stage widths are ``64 * k``,
            ``128 * k``, ``256 * k`` and ``512 * k``.
        shortcut_type: Forwarded unchanged to ``_make_layer``.
        num_classes: Number of output features of the final linear
            layer.
    """
    # Initialise nn.Module bookkeeping before attaching any attribute.
    super(WideResNet, self).__init__()
    self.inplanes = 64

    # Stem: 7x7x7 convolution striding only over the last two dims,
    # followed by a max-pool that strides over all three.
    self.conv1 = nn.Conv3d(
        3,
        64,
        kernel_size=7,
        stride=(1, 2, 2),
        padding=(3, 3, 3),
        bias=False)
    self.bn1 = nn.BatchNorm3d(64)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

    # Four residual stages whose widths are scaled by ``k``.
    self.layer1 = self._make_layer(block, 64 * k, layers[0], shortcut_type)
    self.layer2 = self._make_layer(
        block, 128 * k, layers[1], shortcut_type, stride=2)
    self.layer3 = self._make_layer(
        block, 256 * k, layers[2], shortcut_type, stride=2)
    self.layer4 = self._make_layer(
        block, 512 * k, layers[3], shortcut_type, stride=2)

    # The strides requested above multiply to 16 along the first dim
    # (maxpool + layer2..layer4) and to 32 along the last two (conv1
    # adds another factor of 2). NOTE(review): this assumes
    # ``_make_layer`` applies ``stride`` as a downsampling step --
    # confirm against its body.
    last_duration = int(math.ceil(sample_duration / 16))
    last_size = int(math.ceil(sample_size / 32))
    self.avgpool = nn.AvgPool3d(
        (last_duration, last_size, last_size), stride=1)
    self.fc = nn.Linear(512 * k * block.expansion, num_classes)

    for m in self.modules():
        if isinstance(m, nn.Conv3d):
            # ``kaiming_normal_`` fills the parameter in place, so the
            # weight does not need to be re-assigned; it replaces the
            # deprecated ``kaiming_normal`` spelling.
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        elif isinstance(m, nn.BatchNorm3d):
            # Start every batch-norm with unit scale and zero shift.
            m.weight.data.fill_(1)
            m.bias.data.zero_()
| 116 | |
| 117 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): |
| 118 | downsample = None |