(name)
| 69 | |
| 70 | |
def get_activation(name):
    """Return an element-wise activation callable described by ``name``.

    ``name`` is matched case-insensitively against these specs:

    * ``None`` or ``'none'``  -- identity.
    * ``'scale<k>'``          -- clamp to ``[0, k]`` then divide by ``k`` (output in ``[0, 1]``).
    * ``'clamp<k>'``          -- clamp to ``[0, k]``.
    * ``'mul<k>'``            -- multiply by ``k``.
    * ``'lin2srgb'``          -- linear -> sRGB transfer curve, clamped to ``[0, 1]``.
    * ``'trunc_exp'``         -- the module-level ``trunc_exp`` (defined elsewhere in this file).
    * ``'+<k>'`` / ``'-<k>'`` -- add the signed offset ``k``.
    * ``'sigmoid'`` / ``'tanh'``.
    * anything else           -- ``getattr(F, name)`` (``F`` is presumably
      ``torch.nn.functional``; confirm against the file's imports).

    Args:
        name: Activation spec string, or ``None``.

    Returns:
        A callable mapping a tensor to a tensor of the same shape.

    Raises:
        ValueError: if the numeric part of a ``scale``/``clamp``/``mul``/offset
            spec is not a valid float, or a ``scale`` factor is not positive.
        AttributeError: if ``name`` matches no spec and ``F`` has no such attribute.
    """
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == 'none':
        return lambda x: x

    # Prefix specs are tested before the F fallback, so any F function whose
    # name starts with 'scale', 'clamp' or 'mul' is not reachable through here.
    if name.startswith('scale'):
        scale_factor = float(name[5:])
        if not scale_factor > 0:  # also rejects NaN
            # 'scale0' would divide by zero (NaN output); a negative factor makes
            # clamp(0, k) collapse every element to k.
            raise ValueError(f"scale factor must be positive, got {name!r}")
        return lambda x: x.clamp(0., scale_factor) / scale_factor
    if name.startswith('clamp'):
        clamp_max = float(name[5:])
        return lambda x: x.clamp(0., clamp_max)
    if name.startswith('mul'):
        mul_factor = float(name[3:])
        return lambda x: x * mul_factor

    if name == 'lin2srgb':
        # Linear segment below 0.0031308, 1/2.4 power curve above. The inner
        # clamp keeps pow() off negative inputs in the branch torch.where
        # discards, so that branch never evaluates to NaN.
        return lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0., 1.)
    if name == 'trunc_exp':
        return trunc_exp

    if name.startswith(('+', '-')):
        # Parse once here: malformed specs fail immediately instead of on
        # first use, and the string is not re-parsed on every call.
        offset = float(name)
        return lambda x: x + offset

    if name == 'sigmoid':
        return torch.sigmoid
    if name == 'tanh':
        return torch.tanh
    return getattr(F, name)
| 98 | |
| 99 | |
| 100 | def dot(x, y): |
no outgoing calls
no test coverage detected