MCP · copy · Index your code
hub / github.com/tensorflow/models / _resource_apply_dense

Method _resource_apply_dense

official/modeling/optimization/lamb.py:132–170  ·  view source on GitHub ↗
(self, grad, var, apply_state=None)

Source from the content-addressed store, hash-verified

130 )
131
def _resource_apply_dense(self, grad, var, apply_state=None):
  """Applies a single dense LAMB step to `var`.

  The step proceeds as follows:
    1. Refresh the "m" and "v" moment slots from `grad`.
    2. Divide each moment by its bias-correction factor.
    3. Form the direction m_hat / (sqrt(v_hat) + epsilon).
    4. Optionally add `weight_decay_rate * var` to that direction.
    5. Optionally rescale by the ratio ||var|| / ||update||, falling back
       to 1.0 when either norm is zero.
    6. Write `var - ratio * lr_t * update` back into `var`.

  Args:
    grad: Dense gradient with respect to `var`.
    var: Variable to update; its "m" and "v" slots are assigned as well.
    apply_state: Optional dict keyed by `(device, base_dtype)` whose values
      are coefficient dicts (read here by keys such as "lr_t", "beta_1_t",
      "epsilon"). If it is None, has no entry for this variable, or the
      entry is empty, `self._fallback_apply_state` supplies the
      coefficients instead.

  Returns:
    The result of `var.assign(...)` with the new variable value.
  """
  device, dtype = var.device, var.dtype.base_dtype
  coeffs = (apply_state or {}).get((device, dtype))
  if not coeffs:
    coeffs = self._fallback_apply_state(device, dtype)
  locking = self._use_locking

  # First moment: m <- beta_1 * m + (1 - beta_1) * grad.
  m_slot = self.get_slot(var, "m")
  scaled_grad = grad * coeffs["one_minus_beta_1_t"]
  new_m = m_slot.assign(m_slot * coeffs["beta_1_t"] + scaled_grad,
                        use_locking=locking)

  # Second moment: v <- beta_2 * v + (1 - beta_2) * grad^2.
  v_slot = self.get_slot(var, "v")
  scaled_sq_grad = (grad * grad) * coeffs["one_minus_beta_2_t"]
  new_v = v_slot.assign(v_slot * coeffs["beta_2_t"] + scaled_sq_grad,
                        use_locking=locking)

  # Bias-corrected moments feed the Adam-style update direction.
  m_hat = new_m / (1.0 - coeffs["beta_1_power"])
  v_hat = new_v / (1.0 - coeffs["beta_2_power"])
  update = m_hat / (tf.sqrt(v_hat) + coeffs["epsilon"])

  name = self._get_variable_name(var.name)
  if self._do_use_weight_decay(name):
    # The decay term is folded into the direction itself, so it is scaled
    # by the trust ratio and the learning rate below like the rest of it.
    update = update + coeffs["weight_decay_rate"] * var

  trust_ratio = 1.0
  if self._do_layer_adaptation(name):
    weight_norm = tf.norm(var, ord=2)
    update_norm = tf.norm(update, ord=2)
    # Use 1.0 whenever either norm is zero. This keeps a zero-valued
    # variable from getting a zero step and avoids dividing by a zero
    # update norm.
    nonzero_update_ratio = tf.where(
        tf.greater(update_norm, 0), weight_norm / update_norm, 1.0)
    trust_ratio = tf.where(
        tf.greater(weight_norm, 0), nonzero_update_ratio, 1.0)

  return var.assign(var - trust_ratio * coeffs["lr_t"] * update,
                    use_locking=locking)
171
172 def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
173 var_device, var_dtype = var.device, var.dtype.base_dtype

Callers

nothing calls this directly

Calls 5

_get_variable_name · Method · 0.95
_do_use_weight_decay · Method · 0.95
_do_layer_adaptation · Method · 0.95
get · Method · 0.45
assign · Method · 0.45

Tested by

no test coverage detected