MCPcopy
hub / github.com/ddbourgin/numpy-ml / TorchWavenetModule

Class TorchWavenetModule

numpy_ml/tests/nn_torch_models.py:578–712  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

576
577
578class TorchWavenetModule(nn.Module):
def __init__(self, params, hparams, conv_1x1_pad):
    """
    Build the PyTorch layers of a WaveNet module and load numpy-ml weights.

    Two convolutions are created:

    * ``self.conv_dilation``: a ``TorchCausalConv1d``.
    * ``self.conv_1x1``: a plain ``nn.Conv1d``.

    Their weights and biases are then overwritten with the values from
    ``params`` so the PyTorch module can be compared against the numpy-ml
    implementation.

    Parameters
    ----------
    params : dict
        Parameter dict. For each ``name`` in ``("conv_dilation", "conv_1x1")``
        it must provide ``params["components"][name]["W"]`` and
        ``params["components"][name]["b"]``.
        ``W`` is laid out as ``(f[0], n_in, n_out)`` and is transposed to
        PyTorch's ``(n_out, n_in, f[0])``. ``b`` is flattened.
    hparams : dict
        Hyperparameter dict. For each ``name``,
        ``hparams["components"][name]`` supplies ``in_ch``, ``out_ch``,
        ``kernel_width``, ``stride`` and ``dilation``.
    conv_1x1_pad
        Passed unchanged as ``padding`` to the 1x1 ``nn.Conv1d``.
    """
    super(TorchWavenetModule, self).__init__()
    comp_hp = hparams["components"]
    comp_p = params["components"]

    # NOTE(review): the `+ 1` suggests numpy-ml counts dilation as inserted
    # gaps (0 = undilated) while PyTorch uses element spacing
    # (1 = undilated). Confirm against numpy_ml's Conv1D docs.
    dil_hp = comp_hp["conv_dilation"]
    self.conv_dilation = TorchCausalConv1d(
        in_channels=dil_hp["in_ch"],
        out_channels=dil_hp["out_ch"],
        kernel_size=dil_hp["kernel_width"],
        stride=dil_hp["stride"],
        dilation=dil_hp["dilation"] + 1,
        bias=True,
    )

    one_hp = comp_hp["conv_1x1"]
    self.conv_1x1 = nn.Conv1d(
        in_channels=one_hp["in_ch"],
        out_channels=one_hp["out_ch"],
        kernel_size=one_hp["kernel_width"],
        stride=one_hp["stride"],
        padding=conv_1x1_pad,
        dilation=one_hp["dilation"] + 1,
        bias=True,
    )

    def _load_weights(conv, layer_params):
        """Replace ``conv``'s weight/bias with float32 copies of ``W``/``b``."""
        # (f[0], n_in, n_out) -> (n_out, n_in, f[0])
        W = np.moveaxis(layer_params["W"], [0, 1, 2], [-1, -2, -3])
        b = layer_params["b"].flatten()
        conv.weight = nn.Parameter(torch.tensor(W, dtype=torch.float32))
        conv.bias = nn.Parameter(torch.tensor(b, dtype=torch.float32))
        # Sanity-check that the transposed numpy shapes match the parameters.
        assert conv.weight.shape == W.shape
        assert conv.bias.shape == b.shape

    _load_weights(self.conv_dilation, comp_p["conv_dilation"])
    _load_weights(self.conv_1x1, comp_p["conv_1x1"])
617
618 def forward(self, X_main, X_skip):
619 # (N, W, C) -> (N, C, W)
620 self.X_main = np.moveaxis(X_main, [0, 1, 2], [0, -1, -2])
621 self.X_main = torchify(self.X_main)
622 self.X_main.retain_grad()
623
624 self.conv_dilation_out = self.conv_dilation(self.X_main)
625 self.conv_dilation_out.retain_grad()
626
627 self.tanh_out = torch.tanh(self.conv_dilation_out)
628 self.sigm_out = torch.sigmoid(self.conv_dilation_out)
629
630 self.tanh_out.retain_grad()
631 self.sigm_out.retain_grad()
632
633 self.multiply_gate_out = self.tanh_out * self.sigm_out
634 self.multiply_gate_out.retain_grad()
635

Callers 1

test_WaveNetModuleFunction · 0.85

Calls

no outgoing calls

Tested by 1

test_WaveNetModuleFunction · 0.68