(self, params, hparams, conv_1x1_pad)
| 577 | |
| 578 | class TorchWavenetModule(nn.Module): |
def __init__(self, params, hparams, conv_1x1_pad):
    """Build one WaveNet module's two convolutions and load their parameters.

    Creates ``self.conv_dilation`` (a ``TorchCausalConv1d``) and
    ``self.conv_1x1`` (an ``nn.Conv1d``) from ``hparams``, then replaces each
    layer's ``weight`` and ``bias`` with copies of the arrays in ``params``.

    Parameters
    ----------
    params : dict
        Nested dict; ``params["components"][name]["W"]`` and
        ``params["components"][name]["b"]`` are read for ``name`` in
        ``"conv_dilation"`` and ``"conv_1x1"``. ``W`` is laid out as
        ``(f[0], n_in, n_out)`` and is transposed to PyTorch's
        ``(n_out, n_in, f[0])``; ``b`` is flattened to 1-D.
    hparams : dict
        Nested dict; ``hparams["components"][name]`` supplies ``in_ch``,
        ``out_ch``, ``kernel_width``, ``stride`` and ``dilation`` for each
        of the two layers.
    conv_1x1_pad
        Value passed as ``padding`` to the 1x1 ``nn.Conv1d``.

    Raises
    ------
    ValueError
        If a supplied weight or bias does not have the shape of the
        parameter its layer was constructed with (i.e. ``params`` disagrees
        with ``hparams``).
    """
    super(TorchWavenetModule, self).__init__()
    components = hparams["components"]
    dil_cfg = components["conv_dilation"]
    one_cfg = components["conv_1x1"]

    # The ``+ 1`` maps hparams' dilation onto PyTorch's convention, where 1
    # means undilated (presumably hparams uses 0 for undilated -- confirm).
    self.conv_dilation = TorchCausalConv1d(
        in_channels=dil_cfg["in_ch"],
        out_channels=dil_cfg["out_ch"],
        kernel_size=dil_cfg["kernel_width"],
        stride=dil_cfg["stride"],
        dilation=dil_cfg["dilation"] + 1,
        bias=True,
    )

    self.conv_1x1 = nn.Conv1d(
        in_channels=one_cfg["in_ch"],
        out_channels=one_cfg["out_ch"],
        kernel_size=one_cfg["kernel_width"],
        stride=one_cfg["stride"],
        padding=conv_1x1_pad,
        dilation=one_cfg["dilation"] + 1,
        bias=True,
    )

    def _load_params(layer, name):
        # Overwrite layer.weight / layer.bias with params["components"][name].
        W = params["components"][name]["W"]
        b = params["components"][name]["b"]
        # (f[0], n_in, n_out) -> (n_out, n_in, f[0])
        W = np.moveaxis(W, [0, 1, 2], [-1, -2, -3])
        b = b.flatten()
        for attr, value in (("weight", W), ("bias", b)):
            # Compare against the shape the layer was built with; comparing
            # the new Parameter to the array it was made from always passes
            # and would let a params/hparams mismatch slip through.
            current = getattr(layer, attr, None)
            if current is not None and tuple(current.shape) != value.shape:
                raise ValueError(
                    f"{name}.{attr}: expected shape {tuple(current.shape)}, "
                    f"got {value.shape}"
                )
            # torch.tensor copies, so the module never aliases params' arrays.
            new_param = nn.Parameter(torch.tensor(value, dtype=torch.float32))
            setattr(layer, attr, new_param)

    _load_params(self.conv_dilation, "conv_dilation")
    _load_params(self.conv_1x1, "conv_1x1")
| 617 | |
| 618 | def forward(self, X_main, X_skip): |
| 619 | # (N, W, C) -> (N, C, W) |
nothing calls this directly
no test coverage detected