(self, conv)
| 168 | return x |
| 169 | |
def init_weight(self, conv):
    """Reset ``conv`` to a single-tap, eye-matrix initialisation.

    The layer's weight is replaced by a new ``nn.Parameter`` that is zero
    everywhere except at kernel position ``(1, 0, 0)``, which is set to
    ``torch.eye(c1, c2)`` where ``c1, c2`` are the first two weight
    dimensions. The bias, when the layer has one, is zeroed in place.

    Args:
        conv: A module whose ``weight`` is a 5-D tensor with a first kernel
            dimension of at least 2 (presumably an ``nn.Conv3d`` with a
            temporal kernel size of 3, so that index 1 is the centre tap --
            TODO confirm against the callers).

    Raises:
        ValueError: If ``conv.weight`` is not 5-dimensional.
        IndexError: If the first kernel dimension is smaller than 2.
    """
    # Start from zeros with the same shape/dtype/device as the current
    # weight; the detached source keeps the new tensor out of autograd.
    weight = torch.zeros_like(conv.weight.detach())

    # Unpacking five values keeps the explicit "must be 5-D" check.
    c1, c2, _, _, _ = weight.shape

    # Build the eye matrix directly on the weight's device; the indexed
    # assignment casts it to the weight's dtype if they differ.
    weight[:, :, 1, 0, 0] = torch.eye(c1, c2, device=weight.device)

    # A fresh Parameter is assigned (rather than mutating in place) so the
    # observable result matches the previous implementation.
    conv.weight = nn.Parameter(weight)

    # Layers created with ``bias=False`` expose ``bias`` as None.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias.data)
| 180 | |
| 181 | def init_weight2(self, conv): |
| 182 | conv_weight = conv.weight.data.detach().clone() |