(self, conv)
| 180 | nn.init.zeros_(conv.bias.data) |
| 181 | |
def init_weight2(self, conv):
    """Initialise ``conv`` so both halves of its output channels copy the input.

    The weight is zeroed. Then, at the last index of dim 2 and index 0 of
    dims 3 and 4, an identity-like matrix is written into the first half and
    into the second half of the output channels. Within each half, output
    channel ``i`` receives input channel ``i`` with weight 1, for ``i`` below
    the input-channel count. ``conv.weight`` is replaced by a new
    ``nn.Parameter``, and the bias (if the module has one) is zeroed.

    Args:
        conv: Module whose ``weight`` is a 5-D tensor
            ``(out_channels, in_channels, d2, d3, d4)``. Presumably an
            ``nn.Conv3d`` — TODO confirm against callers. ``conv.bias`` may
            be ``None``.

    Raises:
        ValueError: If ``conv.weight`` is not 5-D (tuple unpacking fails).
    """
    # Fresh zero tensor matching the weight's dtype/device/layout; ``.data``
    # keeps the new tensor out of the autograd graph.
    weight = torch.zeros_like(conv.weight.data)
    c_out, c_in, _, _, _ = weight.shape
    half = c_out // 2
    # Build the identity blocks on the weight's own dtype/device. Size each
    # block to its half: with an odd out_channels count, the second half is
    # one row larger. A single shared (half, c_in) matrix cannot fill that
    # half and raised a shape-mismatch error.
    eye_kwargs = {"dtype": weight.dtype, "device": weight.device}
    weight[:half, :, -1, 0, 0] = torch.eye(half, c_in, **eye_kwargs)
    weight[half:, :, -1, 0, 0] = torch.eye(c_out - half, c_in, **eye_kwargs)
    # NOTE(review): when d3/d4 exceed 1, index 0 is the corner tap rather
    # than the center one — confirm that is intended.
    conv.weight = nn.Parameter(weight)
    # A conv built with bias=False has ``bias is None``; nothing to reset.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias.data)
| 191 | |
| 192 | |
| 193 | class ResidualBlock(nn.Module): |