Create a 1D, 2D, or 3D convolution module.
(dims, *args, **kwargs)
| 219 | return super().forward(x.float()).type(x.dtype) |
| 220 | |
def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer whose dimensionality is chosen by ``dims``.

    :param dims: number of spatial dimensions; must equal 1, 2, or 3.
    :param args: positional arguments forwarded unchanged to the selected
        ``nn.ConvNd`` constructor.
    :param kwargs: keyword arguments forwarded unchanged to the selected
        ``nn.ConvNd`` constructor.
    :return: a new ``nn.Conv1d``, ``nn.Conv2d``, or ``nn.Conv3d`` instance.
    :raises ValueError: if ``dims`` is not 1, 2, or 3.
    """
    # Ordered (ndim, class) pairs compared with ``==`` rather than a dict
    # lookup, so any value equal to 1/2/3 is accepted and unhashable inputs
    # still fall through to the ValueError below.
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for ndim, conv_cls in candidates:
        if dims == ndim:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
| 232 | |
| 233 | |
| 234 | def linear(*args, **kwargs): |