https://github.com/pytorch/pytorch/issues/1333 NB: this only ensures that the convolution output length is the same as the input length iff stride = 1. Otherwise, the input and output lengths will differ.
| 539 | |
| 540 | |
class TorchCausalConv1d(torch.nn.Conv1d):
    """1-D convolution whose outputs never depend on future time steps.

    See https://github.com/pytorch/pytorch/issues/1333.

    The parent ``Conv1d`` is built with ``(kernel_size - 1) * dilation`` zeros
    of padding on *both* sides of the time axis. ``forward`` then keeps only
    the leading outputs whose receptive field ends at or before the last real
    input sample. Output ``j`` therefore depends only on inputs at positions
    ``<= j * stride``.

    NB: this only ensures that the output length equals the input length when
    ``stride == 1``. For ``stride > 1`` the output has
    ``ceil(L_in / stride)`` time steps.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size (an int; used arithmetically to
            compute the padding).
        stride: Convolution stride.
        dilation: Spacing between kernel taps (an int; used arithmetically to
            compute the padding).
        groups: Number of blocked connections, as in ``torch.nn.Conv1d``.
        bias: Whether to add a learnable bias, as in ``torch.nn.Conv1d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        # Enough left context for the last kernel tap to reach x[0].
        padding = (kernel_size - 1) * dilation

        super(TorchCausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.__padding = padding

    def forward(self, input):
        """Apply the convolution and drop outputs that would see the future.

        Args:
            input: Tensor whose last dimension is time, as accepted by
                ``torch.nn.Conv1d``.

        Returns:
            The convolution output truncated to ``(L_in - 1) // stride + 1``
            time steps along the last dimension.
        """
        result = super(TorchCausalConv1d, self).forward(input)
        if self.__padding == 0:
            # No padding means no output reaches into padded future samples.
            return result
        # With p pad samples on each side, output j covers original input
        # positions [j*stride - p, j*stride]. It is causal iff
        # j*stride <= L_in - 1. Trimming a fixed `p` steps is only right for
        # stride == 1; for larger strides it drops valid outputs and can
        # even return an empty tensor.
        n_causal = (input.shape[-1] - 1) // self.stride[0] + 1
        # `...` keeps this working for both batched (N, C, L) and
        # unbatched (C, L) inputs.
        return result[..., :n_causal]
| 576 | |
| 577 | |
| 578 | class TorchWavenetModule(nn.Module): |
no outgoing calls