| 1772 | |
| 1773 | |
class ModelConv1DKernel1(nn.Module):
    """Small Conv1d network whose convolution uses a kernel of width 1.

    Pipeline: cast to ``torch.float`` -> view as (N, 3, 10) -> Conv1d with
    kernel_size=1 (3 -> 10 channels) -> ReLU -> flatten to (N, 100) ->
    Linear(100, 2). The input must contain a multiple of 30 elements so it
    can be reshaped into 3 channels of length 10.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in this exact order so parameter
        # initialisation draws from the global RNG in the same sequence
        # (convolution first, then the linear head).
        self.conv1d = nn.Conv1d(3, 10, 1)  # in_channels, out_channels, kernel_size
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        # A width-1 kernel with default stride/padding keeps seq_len at 10, so
        # the flattened feature count is out_channels (10) * seq_len (10).
        self.lin0 = nn.Linear(10 * 10, 2)
        # Inputs are cast to this dtype at the start of forward().
        self.dtype = torch.float

    def forward(self, x):
        """Return the (N, 2) output of the final linear layer.

        ``x`` may have any shape whose element count is a multiple of 30; it
        is cast to ``self.dtype`` and viewed as (batch, channels=3, seq_len=10).
        """
        features = x.to(self.dtype).reshape(-1, 3, 10)  # (batch, channels, seq_len)
        activated = self.relu(self.conv1d(features))
        return self.lin0(self.flat(activated))
| 1791 | |
| 1792 | |
| 1793 | class ModelConv3D(nn.Module): |
no outgoing calls
searching dependent graphs…