Create a 1D, 2D, or 3D convolution module.
Signature: `conv_nd(dims, *args, **kwargs)`
| 20 | |
| 21 | |
def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer of the requested dimensionality.

    :param dims: number of spatial dimensions; compared against 1, 2 and 3.
    :param args: positional arguments forwarded unchanged to the conv constructor.
    :param kwargs: keyword arguments forwarded unchanged to the conv constructor.
    :return: an ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` instance.
    :raises ValueError: if ``dims`` equals none of 1, 2 or 3.
    """
    # Ordered (dimensionality, layer class) pairs. An equality scan is used
    # instead of a hashed dict lookup so that `dims` is compared with `==`
    # in the same order as a plain if/elif chain would compare it.
    candidates = (
        (1, nn.Conv1d),
        (2, nn.Conv2d),
        (3, nn.Conv3d),
    )
    for n, conv_cls in candidates:
        if dims == n:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
| 33 | |
| 34 | |
| 35 | def linear(*args, **kwargs): |