Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`. String identifiers are checked first: if the name is one of the customized activations not in TF, the corresponding activation is returned. For non-customized activation names and callable identifiers, it always falls back to `tf_keras.activations.get`.
(identifier, use_keras_layer=False, **kwargs)
| 85 | |
| 86 | |
def get_activation(identifier, use_keras_layer=False, **kwargs):
  """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.

  String identifiers are lower-cased and checked first. If the name is one of
  the customized activations not in TF, the corresponding activation is
  returned. For non-customized activation names and non-string identifiers
  (e.g. callables), this always falls back to `tf_keras.activations.get`.

  Prefers keras layers when `use_keras_layer=True`. The keras-layer path
  supports exactly these names:
    'relu', 'linear', 'identity' (built as 'linear'), 'swish', 'sigmoid',
    'relu6', 'leaky_relu', 'hard_swish', 'hard_sigmoid', 'mish', and 'gelu'.
  Any other string name falls through to the plain function lookup, even when
  `use_keras_layer=True`.

  Args:
    identifier: String name of the activation function or callable.
    use_keras_layer: If True, return a `tf_keras.layers.Activation` layer when
      `identifier` is an allow-listed string name (see above).
    **kwargs: Keyword arguments used to instantiate an activation function.
      They are applied only to 'leaky_relu' and 'gelu' when
      `use_keras_layer=True`. In every other case they are ignored.
      For example: get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)

  Returns:
    A `tf_keras.layers.Activation` layer when `use_keras_layer=True` and
    `identifier` is an allow-listed name. Otherwise, the result of
    `tf_keras.activations.get` for the resolved identifier.
  """
  if isinstance(identifier, six.string_types):
    # Matching below is case-insensitive.
    identifier = str(identifier).lower()

    if use_keras_layer:
      # Values are either keras activation name strings or callables.
      # kwargs are bound only into the two partials below.
      keras_layer_allowlist = {
          "relu": "relu",
          "linear": "linear",
          "identity": "linear",  # Alias: built as a 'linear' layer.
          "swish": "swish",
          "sigmoid": "sigmoid",
          "relu6": tf.nn.relu6,
          "leaky_relu": functools.partial(tf.nn.leaky_relu, **kwargs),
          "hard_swish": activations.hard_swish,
          "hard_sigmoid": activations.hard_sigmoid,
          "mish": activations.mish,
          "gelu": functools.partial(tf.nn.gelu, **kwargs),
      }
      if identifier in keras_layer_allowlist:
        return tf_keras.layers.Activation(keras_layer_allowlist[identifier])
      # Not allow-listed: fall through to the function lookup.

    # Customized activations implemented in this project's `activations`
    # module, resolved by name before deferring to tf_keras.
    name_to_fn = {
        "gelu": activations.gelu,
        "simple_swish": activations.simple_swish,
        "hard_swish": activations.hard_swish,
        "relu6": activations.relu6,
        "hard_sigmoid": activations.hard_sigmoid,
        "identity": activations.identity,
        "mish": activations.mish,
    }
    if identifier in name_to_fn:
      return tf_keras.activations.get(name_to_fn[identifier])

  # Non-string identifiers and non-customized names.
  return tf_keras.activations.get(identifier)
| 138 | |
| 139 | |
| 140 | def get_shape_list(tensor, expected_rank=None, name=None): |