MCPcopy
hub / github.com/ddbourgin/numpy-ml / __init__

Method __init__

numpy_ml/tests/nn_torch_models.py:579–616  ·  view source on GitHub ↗
(self, params, hparams, conv_1x1_pad)

Source from the content-addressed store, hash-verified

577
578class TorchWavenetModule(nn.Module):
def __init__(self, params, hparams, conv_1x1_pad):
    """Build torch equivalents of a numpy-ml WaveNet module's two convolutions.

    Constructs ``conv_dilation`` (a ``TorchCausalConv1d``) and ``conv_1x1``
    (an ``nn.Conv1d``) from the numpy-ml hyperparameters. It then replaces
    their default-initialized weights and biases with the numpy-ml
    parameters. Presumably this lets the test suite compare the two
    implementations on identical weights.

    Parameters
    ----------
    params : dict
        ``params["components"][name]`` must provide arrays ``"W"`` and ``"b"``
        for ``name`` in ``("conv_dilation", "conv_1x1")``. ``W`` is laid out
        as ``(f[0], n_in, n_out)``. ``b`` is flattened before use.
    hparams : dict
        ``hparams["components"][name]`` must provide ``"in_ch"``,
        ``"out_ch"``, ``"kernel_width"``, ``"stride"`` and ``"dilation"``
        for both components.
    conv_1x1_pad
        Forwarded unchanged as the ``padding`` argument of the 1x1
        ``nn.Conv1d``.

    Raises
    ------
    AssertionError
        Sanity check: raised if a copied weight/bias tensor does not have the
        shape of the (transposed) source array.
    """
    super(TorchWavenetModule, self).__init__()

    dil_hp = hparams["components"]["conv_dilation"]
    c1x1_hp = hparams["components"]["conv_1x1"]

    # torch uses dilation=1 for an undilated kernel. The numpy-ml value is
    # offset by one from that (presumably 0 means "undilated" there -- confirm).
    self.conv_dilation = TorchCausalConv1d(
        in_channels=dil_hp["in_ch"],
        out_channels=dil_hp["out_ch"],
        kernel_size=dil_hp["kernel_width"],
        stride=dil_hp["stride"],
        dilation=dil_hp["dilation"] + 1,
        bias=True,
    )

    self.conv_1x1 = nn.Conv1d(
        in_channels=c1x1_hp["in_ch"],
        out_channels=c1x1_hp["out_ch"],
        kernel_size=c1x1_hp["kernel_width"],
        stride=c1x1_hp["stride"],
        padding=conv_1x1_pad,
        dilation=c1x1_hp["dilation"] + 1,
        bias=True,
    )

    def _copy_weights(layer, name):
        """Overwrite ``layer``'s weight and bias with numpy-ml params ``name``."""
        W = params["components"][name]["W"]
        b = params["components"][name]["b"].flatten()
        # (f[0], n_in, n_out) -> (n_out, n_in, f[0]), torch's Conv1d layout
        W = np.moveaxis(W, [0, 1, 2], [-1, -2, -3])
        layer.weight = nn.Parameter(torch.FloatTensor(W))
        layer.bias = nn.Parameter(torch.FloatTensor(b))
        # The copy must not reshape anything.
        assert layer.weight.shape == W.shape
        assert layer.bias.shape == b.shape

    _copy_weights(self.conv_dilation, "conv_dilation")
    _copy_weights(self.conv_1x1, "conv_1x1")
617
618 def forward(self, X_main, X_skip):
619 # (N, W, C) -> (N, C, W)

Callers

nothing calls this directly

Calls 2

TorchCausalConv1d (class) · 0.85
__init__ (method) · 0.45

Tested by

no test coverage detected