| 39 | |
| 40 | |
| 41 | class ActivationInitializer(object): |
| 42 | def __init__(self, param=None): |
| 43 | """ |
| 44 | A class for initializing activation functions. Valid `param` values |
| 45 | are: |
| 46 | (a) ``__str__`` representations of an `ActivationBase` instance |
| 47 | (b) `ActivationBase` instance |
| 48 | |
| 49 | If `param` is `None`, return the identity function: f(X) = X |
| 50 | """ |
| 51 | self.param = param |
| 52 | |
| 53 | def __call__(self): |
| 54 | """Initialize activation function""" |
| 55 | param = self.param |
| 56 | if param is None: |
| 57 | act = Identity() |
| 58 | elif isinstance(param, ActivationBase): |
| 59 | act = param |
| 60 | elif isinstance(param, str): |
| 61 | act = self.init_from_str(param) |
| 62 | else: |
| 63 | raise ValueError("Unknown activation: {}".format(param)) |
| 64 | return act |
| 65 | |
| 66 | def init_from_str(self, act_str): |
| 67 | """Initialize activation function from the `param` string""" |
| 68 | act_str = act_str.lower() |
| 69 | if act_str == "relu": |
| 70 | act_fn = ReLU() |
| 71 | elif act_str == "tanh": |
| 72 | act_fn = Tanh() |
| 73 | elif act_str == "selu": |
| 74 | act_fn = SELU() |
| 75 | elif act_str == "sigmoid": |
| 76 | act_fn = Sigmoid() |
| 77 | elif act_str == "identity": |
| 78 | act_fn = Identity() |
| 79 | elif act_str == "hardsigmoid": |
| 80 | act_fn = HardSigmoid() |
| 81 | elif act_str == "softplus": |
| 82 | act_fn = SoftPlus() |
| 83 | elif act_str == "exponential": |
| 84 | act_fn = Exponential() |
| 85 | elif "affine" in act_str: |
| 86 | r = r"affine\(slope=(.*), intercept=(.*)\)" |
| 87 | slope, intercept = re.match(r, act_str).groups() |
| 88 | act_fn = Affine(float(slope), float(intercept)) |
| 89 | elif "leaky relu" in act_str: |
| 90 | r = r"leaky relu\(alpha=(.*)\)" |
| 91 | alpha = re.match(r, act_str).groups()[0] |
| 92 | act_fn = LeakyReLU(float(alpha)) |
| 93 | elif "gelu" in act_str: |
| 94 | r = r"gelu\(approximate=(.*)\)" |
| 95 | approx = re.match(r, act_str).groups()[0] == "true" |
| 96 | act_fn = GELU(approximation=approx) |
| 97 | elif "elu" in act_str: |
| 98 | r = r"elu\(alpha=(.*)\)" |
no outgoing calls
no test coverage detected