Create a 1D, 2D, or 3D average pooling module.
(dims, *args, **kwargs)
| 40 | |
| 41 | |
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average-pooling layer matching the requested dimensionality.

    :param dims: spatial dimensionality; must compare equal to 1, 2, or 3.
    :param args: positional arguments forwarded unchanged to the selected
        pooling class constructor.
    :param kwargs: keyword arguments forwarded unchanged to the constructor.
    :return: an ``nn.AvgPool1d``, ``nn.AvgPool2d`` or ``nn.AvgPool3d`` instance.
    :raises ValueError: if ``dims`` is not 1, 2, or 3.
    """
    # Compared in order with ``==`` (rather than a dict lookup) so any value
    # that compares equal to 1/2/3 is accepted, exactly as an if/elif chain.
    candidates = (
        (1, nn.AvgPool1d),
        (2, nn.AvgPool2d),
        (3, nn.AvgPool3d),
    )
    for rank, pool_cls in candidates:
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
| 53 | |
| 54 | |
| 55 | def update_ema(target_params, source_params, rate=0.99): |