(activation)
| 30 | |
| 31 | |
def get_activation(activation):
    """Construct the activation module selected by ``activation``.

    Args:
        activation: Name of the activation; one of ``'silu'``,
            ``'gelu_jit'``, ``'gelu'`` or ``'none'``.

    Returns:
        A newly constructed module: ``torch.nn.SiLU()``, ``GELUJit()``,
        ``torch.nn.GELU()`` or ``torch.nn.Identity()`` respectively.

    Raises:
        ValueError: If ``activation`` equals none of the names above.
    """
    # Zero-argument factories, checked in order with ``==``. Only the factory
    # whose name matches is called, so e.g. GELUJit is looked up and
    # instantiated only when 'gelu_jit' is requested.
    factories = (
        ('silu', lambda: torch.nn.SiLU()),
        ('gelu_jit', lambda: GELUJit()),
        ('gelu', lambda: torch.nn.GELU()),
        ('none', lambda: torch.nn.Identity()),
    )
    for name, factory in factories:
        if activation == name:
            return factory()
    raise ValueError(f'unknown activation type {activation}')
| 43 | |
| 44 | |
| 45 | class GroupNorm32(nn.GroupNorm): |