Repository: github.com/tensorflow/models

Function get_activation

official/modeling/tf_utils.py:87–137

Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`. String identifiers are lowercased and checked first. With `use_keras_layer=True`, an allow-listed name is returned as a `tf_keras.layers.Activation` layer. Otherwise, if the name is one of the custom activations that TF does not provide, that custom activation is returned. All other names, and all callable identifiers, fall back to `tf_keras.activations.get`.
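The lookup order described above can be sketched in plain Python. This is a minimal illustration, not the library code: `_CUSTOM`, `_BUILTINS`, `fallback_get`, and `get_activation_sketch` are hypothetical names. `_CUSTOM` plays the role of the custom activations from `official.modeling.activations`, and `fallback_get` stands in for `tf_keras.activations.get`, so the sketch runs without TensorFlow and operates on floats rather than tensors.

```python
import math


# Hypothetical stand-ins for two custom activations (the real ones live in
# official.modeling.activations and operate on tensors, not floats).
def simple_swish(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def identity(x: float) -> float:
    return x


# Hypothetical stand-in for the built-in registry behind tf_keras.activations.get.
_BUILTINS = {
    "relu": lambda x: max(0.0, x),
    "linear": lambda x: x,
}


def fallback_get(identifier):
    """Mimics the fallback: callables pass through, known names are looked up."""
    if callable(identifier):
        return identifier
    if identifier in _BUILTINS:
        return _BUILTINS[identifier]
    raise ValueError(f"Unknown activation: {identifier!r}")


# Hypothetical custom table, checked before the fallback registry.
_CUSTOM = {
    "simple_swish": simple_swish,
    "identity": identity,
}


def get_activation_sketch(identifier):
    if isinstance(identifier, str):
        identifier = identifier.lower()  # "ReLU" and "relu" resolve the same way
        if identifier in _CUSTOM:
            # The custom function is still routed through the fallback,
            # which returns a callable unchanged.
            return fallback_get(_CUSTOM[identifier])
    return fallback_get(identifier)


print(get_activation_sketch("ReLU")(-2.0))            # 0.0
print(get_activation_sketch("Identity") is identity)  # True
```

Because the custom table is consulted first, a custom entry shadows a fallback entry with the same name. In the real function this is why `"gelu"` resolves to `activations.gelu` rather than to the Keras built-in of the same name.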

(identifier, use_keras_layer=False, **kwargs)
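The `**kwargs` in this signature are not forwarded to every activation. In the source listing they are frozen into `functools.partial` objects for `'leaky_relu'` and `'gelu'` only, on the `use_keras_layer=True` path, and those partials are built before the identifier is checked. The sketch below reproduces that binding with plain-Python stand-ins. The `leaky_relu`, `gelu`, and `build_layer_table` defined here are hypothetical, not the TF functions. It shows two consequences: kwargs passed with any other name are silently ignored, and a mismatched keyword only fails when the bound partial is finally called.

```python
import functools
import math


# Hypothetical plain-Python stand-ins; tf.nn.leaky_relu takes `alpha` and
# tf.nn.gelu takes `approximate`, which these mirror.
def leaky_relu(x: float, alpha: float = 0.2) -> float:
    return x if x >= 0 else alpha * x


def gelu(x: float, approximate: bool = False) -> float:
    if approximate:
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def build_layer_table(**kwargs):
    # Like the allow-list in the listing, kwargs are bound into the two
    # partials up front; nothing checks them until an entry is called.
    return {
        "relu": lambda x: max(0.0, x),
        "leaky_relu": functools.partial(leaky_relu, **kwargs),
        "gelu": functools.partial(gelu, **kwargs),
    }


table = build_layer_table(alpha=0.1)
print(table["leaky_relu"](-1.0))  # -0.1: alpha was applied
print(table["relu"](-1.0))        # 0.0: alpha silently ignored
try:
    table["gelu"](1.0)            # gelu() has no `alpha` parameter
except TypeError as err:
    print("TypeError:", err)
```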


```python
# Imports this function relies on.
import functools

import six
import tensorflow as tf
import tf_keras

from official.modeling import activations


def get_activation(identifier, use_keras_layer=False, **kwargs):
  """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.

  It checks string first and if it is one of customized activation not in TF,
  the corresponding activation will be returned. For non-customized activation
  names and callable identifiers, always fallback to tf_keras.activations.get.

  Prefers using keras layers when use_keras_layer=True. Currently it supports
  'relu', 'linear', 'identity', 'swish', 'sigmoid', 'relu6', 'leaky_relu',
  'hard_swish', 'hard_sigmoid', 'mish', and 'gelu'.

  Args:
    identifier: String name of the activation function or callable.
    use_keras_layer: If True, use keras layer if identifier is allow-listed.
    **kwargs: Keyword arguments to use to instantiate an activation function.
      Available only for 'leaky_relu' and 'gelu' when using keras layers.
      For example: get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)

  Returns:
    A Python function corresponding to the activation function, or a keras
    activation layer when use_keras_layer=True and identifier is allow-listed.
  """
  if isinstance(identifier, six.string_types):
    # Only string identifiers are normalized and looked up by name.
    identifier = str(identifier).lower()
    if use_keras_layer:
      # kwargs take effect only through the two functools.partial entries.
      keras_layer_allowlist = {
          "relu": "relu",
          "linear": "linear",
          "identity": "linear",
          "swish": "swish",
          "sigmoid": "sigmoid",
          "relu6": tf.nn.relu6,
          "leaky_relu": functools.partial(tf.nn.leaky_relu, **kwargs),
          "hard_swish": activations.hard_swish,
          "hard_sigmoid": activations.hard_sigmoid,
          "mish": activations.mish,
          "gelu": functools.partial(tf.nn.gelu, **kwargs),
      }
      if identifier in keras_layer_allowlist:
        return tf_keras.layers.Activation(keras_layer_allowlist[identifier])
    # Custom activations from official.modeling.activations.
    name_to_fn = {
        "gelu": activations.gelu,
        "simple_swish": activations.simple_swish,
        "hard_swish": activations.hard_swish,
        "relu6": activations.relu6,
        "hard_sigmoid": activations.hard_sigmoid,
        "identity": activations.identity,
        "mish": activations.mish,
    }
    if identifier in name_to_fn:
      return tf_keras.activations.get(name_to_fn[identifier])
  # Built-in Keras names and callable identifiers.
  return tf_keras.activations.get(identifier)
```
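Note that `use_keras_layer=True` does not guarantee that a layer comes back. A string outside `keras_layer_allowlist` (for example `'simple_swish'`, or a Keras built-in such as `'tanh'`) falls through to the function path, and a callable identifier skips the string branch entirely. The following is a minimal pure-Python sketch of that control flow, assuming hypothetical stand-ins: `ActivationLayer` plays the role of `tf_keras.layers.Activation`, and `LAYER_ALLOWLIST`, `CUSTOM`, `BUILTINS`, and `get_activation_sketch` are illustrative names, not part of the library.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass
class ActivationLayer:
    """Hypothetical stand-in for tf_keras.layers.Activation."""
    fn: Callable[[float], float]

    def __call__(self, x: float) -> float:
        return self.fn(x)


def relu(x: float) -> float:
    return max(0.0, x)


def simple_swish(x: float) -> float:
    return x / (1.0 + math.exp(-x))


# Hypothetical tables: only "relu" is allow-listed for layer wrapping,
# "simple_swish" exists only as a custom function, and BUILTINS stands in
# for the registry behind tf_keras.activations.get.
LAYER_ALLOWLIST = {"relu": relu}
CUSTOM = {"simple_swish": simple_swish}
BUILTINS = {"relu": relu}


def get_activation_sketch(identifier, use_keras_layer=False):
    if isinstance(identifier, str):
        identifier = identifier.lower()
        # A layer is returned only for allow-listed names.
        if use_keras_layer and identifier in LAYER_ALLOWLIST:
            return ActivationLayer(LAYER_ALLOWLIST[identifier])
        # Everything else falls through to the plain-function path.
        if identifier in CUSTOM:
            return CUSTOM[identifier]
        if identifier in BUILTINS:
            return BUILTINS[identifier]
        raise ValueError(f"Unknown activation: {identifier!r}")
    # Callables are returned unchanged, even when use_keras_layer=True.
    if callable(identifier):
        return identifier
    raise TypeError(f"Cannot interpret {identifier!r} as an activation")


layer = get_activation_sketch("ReLU", use_keras_layer=True)
fn = get_activation_sketch("simple_swish", use_keras_layer=True)
print(type(layer).__name__, type(fn).__name__)  # ActivationLayer function
```

Code that needs a layer for every identifier therefore has to check the return type itself rather than rely on the flag alone.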
