MCPcopy Index your code
hub / github.com/ddbourgin/numpy-ml / TorchSkipConnectionIdentity

Class TorchSkipConnectionIdentity

numpy_ml/tests/nn_torch_models.py:393–538  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

391
392
393class TorchSkipConnectionIdentity(nn.Module):
394 def __init__(self, act_fn, pad1, pad2, params, hparams, momentum=0.9, epsilon=1e-5):
395 super(TorchSkipConnectionIdentity, self).__init__()
396
397 self.conv1 = nn.Conv2d(
398 hparams["in_ch"],
399 hparams["out_ch"],
400 hparams["kernel_shape1"],
401 padding=pad1,
402 stride=hparams["stride1"],
403 bias=True,
404 )
405
406 self.act_fn = act_fn
407
408 self.batchnorm1 = nn.BatchNorm2d(
409 num_features=hparams["out_ch"],
410 momentum=1 - momentum,
411 eps=epsilon,
412 affine=True,
413 )
414
415 self.conv2 = nn.Conv2d(
416 hparams["out_ch"],
417 hparams["out_ch"],
418 hparams["kernel_shape2"],
419 padding=pad2,
420 stride=hparams["stride2"],
421 bias=True,
422 )
423
424 self.batchnorm2 = nn.BatchNorm2d(
425 num_features=hparams["out_ch"],
426 momentum=1 - momentum,
427 eps=epsilon,
428 affine=True,
429 )
430
431 orig, W_swap = [0, 1, 2, 3], [-2, -1, -3, -4]
432 # (f[0], f[1], n_in, n_out) -> (n_out, n_in, f[0], f[1])
433 W = params["components"]["conv1"]["W"]
434 b = params["components"]["conv1"]["b"]
435 W = np.moveaxis(W, orig, W_swap)
436 assert self.conv1.weight.shape == W.shape
437 assert self.conv1.bias.shape == b.flatten().shape
438 self.conv1.weight = nn.Parameter(torch.FloatTensor(W))
439 self.conv1.bias = nn.Parameter(torch.FloatTensor(b.flatten()))
440
441 scaler = params["components"]["batchnorm1"]["scaler"]
442 intercept = params["components"]["batchnorm1"]["intercept"]
443 self.batchnorm1.weight = nn.Parameter(torch.FloatTensor(scaler))
444 self.batchnorm1.bias = nn.Parameter(torch.FloatTensor(intercept))
445
446 # (f[0], f[1], n_in, n_out) -> (n_out, n_in, f[0], f[1])
447 W = params["components"]["conv2"]["W"]
448 b = params["components"]["conv2"]["b"]
449 W = np.moveaxis(W, orig, W_swap)
450 assert self.conv2.weight.shape == W.shape

Callers 1

Calls

no outgoing calls

Tested by 1