Args: norm (str or callable): either one of BN, SyncBN, FrozenBN, GN, or a callable that takes a channel number and returns the normalization layer as an nn.Module. Returns: nn.Module or None: the normalization layer.
(norm, out_channels)
| 119 | |
| 120 | |
def get_norm(norm, out_channels):
    """
    Build a normalization layer for ``out_channels`` channels.

    Args:
        norm (str or callable): either one of the registered names
            "BN", "SyncBN", "FrozenBN", "GN" (plus the debugging variants
            "nnSyncBN", "naiveSyncBN", "naiveSyncBN_N"); or a callable that
            takes a channel number and returns the normalization layer as
            an nn.Module.
        out_channels (int): channel number passed to the layer factory.

    Returns:
        nn.Module or None: the normalization layer, or None when ``norm`` is
        None or an empty string.

    Raises:
        KeyError: if ``norm`` is a non-empty string that is not one of the
            registered names above.
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        # An empty string is treated the same as None: no normalization.
        if not norm:
            return None
        norm_factories = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "naiveSyncBN": NaiveSyncBatchNorm,
            # expose stats_mode N as an option to caller, required for zero-len inputs
            "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
        }
        try:
            norm = norm_factories[norm]
        except KeyError:
            # Keep raising KeyError (callers may catch it), but say what is valid.
            raise KeyError(
                f"Unknown norm {norm!r}; expected one of {sorted(norm_factories)} "
                "or a callable returning an nn.Module"
            ) from None
    return norm(out_channels)
| 149 | |
| 150 | |
| 151 | class NaiveSyncBatchNorm(BatchNorm2d): |
no test coverage detected