Create a 1D, 2D, or 3D convolution module.
conv_nd(dims, *args, **kwargs)
| 276 | |
| 277 | |
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    Args:
        dims: Number of spatial dimensions; must compare equal to 1, 2, or 3.
            Selects ``nn.Conv1d``, ``nn.Conv2d``, or ``nn.Conv3d`` respectively.
        *args: Positional arguments forwarded unchanged to the selected
            convolution class.
        **kwargs: Keyword arguments forwarded unchanged to the selected
            convolution class.

    Returns:
        A newly constructed ``nn.Conv1d``, ``nn.Conv2d``, or ``nn.Conv3d``.

    Raises:
        ValueError: If ``dims`` is not 1, 2, or 3.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    if dims == 2:
        return nn.Conv2d(*args, **kwargs)
    if dims == 3:
        return nn.Conv3d(*args, **kwargs)
    # Any other value has no matching convolution class.
    raise ValueError(f"unsupported dimensions: {dims}")
| 289 | |
| 290 | |
| 291 | def linear(*args, **kwargs): |