| 391 | |
| 392 | |
| 393 | class TorchSkipConnectionIdentity(nn.Module): |
| 394 | def __init__(self, act_fn, pad1, pad2, params, hparams, momentum=0.9, epsilon=1e-5): |
| 395 | super(TorchSkipConnectionIdentity, self).__init__() |
| 396 | |
| 397 | self.conv1 = nn.Conv2d( |
| 398 | hparams["in_ch"], |
| 399 | hparams["out_ch"], |
| 400 | hparams["kernel_shape1"], |
| 401 | padding=pad1, |
| 402 | stride=hparams["stride1"], |
| 403 | bias=True, |
| 404 | ) |
| 405 | |
| 406 | self.act_fn = act_fn |
| 407 | |
| 408 | self.batchnorm1 = nn.BatchNorm2d( |
| 409 | num_features=hparams["out_ch"], |
| 410 | momentum=1 - momentum, |
| 411 | eps=epsilon, |
| 412 | affine=True, |
| 413 | ) |
| 414 | |
| 415 | self.conv2 = nn.Conv2d( |
| 416 | hparams["out_ch"], |
| 417 | hparams["out_ch"], |
| 418 | hparams["kernel_shape2"], |
| 419 | padding=pad2, |
| 420 | stride=hparams["stride2"], |
| 421 | bias=True, |
| 422 | ) |
| 423 | |
| 424 | self.batchnorm2 = nn.BatchNorm2d( |
| 425 | num_features=hparams["out_ch"], |
| 426 | momentum=1 - momentum, |
| 427 | eps=epsilon, |
| 428 | affine=True, |
| 429 | ) |
| 430 | |
| 431 | orig, W_swap = [0, 1, 2, 3], [-2, -1, -3, -4] |
| 432 | # (f[0], f[1], n_in, n_out) -> (n_out, n_in, f[0], f[1]) |
| 433 | W = params["components"]["conv1"]["W"] |
| 434 | b = params["components"]["conv1"]["b"] |
| 435 | W = np.moveaxis(W, orig, W_swap) |
| 436 | assert self.conv1.weight.shape == W.shape |
| 437 | assert self.conv1.bias.shape == b.flatten().shape |
| 438 | self.conv1.weight = nn.Parameter(torch.FloatTensor(W)) |
| 439 | self.conv1.bias = nn.Parameter(torch.FloatTensor(b.flatten())) |
| 440 | |
| 441 | scaler = params["components"]["batchnorm1"]["scaler"] |
| 442 | intercept = params["components"]["batchnorm1"]["intercept"] |
| 443 | self.batchnorm1.weight = nn.Parameter(torch.FloatTensor(scaler)) |
| 444 | self.batchnorm1.bias = nn.Parameter(torch.FloatTensor(intercept)) |
| 445 | |
| 446 | # (f[0], f[1], n_in, n_out) -> (n_out, n_in, f[0], f[1]) |
| 447 | W = params["components"]["conv2"]["W"] |
| 448 | b = params["components"]["conv2"]["b"] |
| 449 | W = np.moveaxis(W, orig, W_swap) |
| 450 | assert self.conv2.weight.shape == W.shape |
no outgoing calls