(self, grad, var, apply_state=None)
| 130 | ) |
| 131 | |
def _resource_apply_dense(self, grad, var, apply_state=None):
    """Apply one dense gradient step to ``var`` and return the assign op.

    The step updates the "m" and "v" slots of ``var`` as exponential
    moving averages of the gradient and the squared gradient, applies
    bias correction, optionally adds a weight-decay term, optionally
    rescales the step by the ratio of the variable norm to the update
    norm, and finally writes the new value into ``var``.

    Args:
        grad: Dense gradient for ``var``.
        var: Variable being optimized; must have "m" and "v" slots.
        apply_state: Optional mapping keyed by ``(device, dtype)`` that
            holds precomputed coefficients. When it is missing, or has
            no (truthy) entry for this variable's key, the coefficients
            come from ``self._fallback_apply_state``.

    Returns:
        The result of ``var.assign(...)`` with the updated value.
    """
    device, dtype = var.device, var.dtype.base_dtype
    coeffs = (apply_state or {}).get((device, dtype))
    if not coeffs:
        coeffs = self._fallback_apply_state(device, dtype)

    # First moment: m_t = beta1 * m + (1 - beta1) * g_t
    first_moment = self.get_slot(var, "m")
    grad_term = grad * coeffs["one_minus_beta_1_t"]
    new_first = first_moment * coeffs["beta_1_t"] + grad_term
    new_first = first_moment.assign(new_first, use_locking=self._use_locking)

    # Second moment: v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    second_moment = self.get_slot(var, "v")
    grad_sq_term = (grad * grad) * coeffs["one_minus_beta_2_t"]
    new_second = second_moment * coeffs["beta_2_t"] + grad_sq_term
    new_second = second_moment.assign(
        new_second, use_locking=self._use_locking
    )

    # Bias-corrected moment estimates.
    first_hat = new_first / (1.0 - coeffs["beta_1_power"])
    second_hat = new_second / (1.0 - coeffs["beta_2_power"])

    denom = tf.sqrt(second_hat) + coeffs["epsilon"]
    update = first_hat / denom

    name = self._get_variable_name(var.name)
    if self._do_use_weight_decay(name):
        update = update + coeffs["weight_decay_rate"] * var

    # The ratio stays at 1.0 unless layer adaptation applies and both
    # norms are strictly positive.
    ratio = 1.0
    if self._do_layer_adaptation(name):
        weight_norm = tf.norm(var, ord=2)
        update_norm = tf.norm(update, ord=2)
        weight_is_positive = tf.greater(weight_norm, 0)
        update_is_positive = tf.greater(update_norm, 0)
        guarded_ratio = tf.where(
            update_is_positive, (weight_norm / update_norm), 1.0
        )
        ratio = tf.where(weight_is_positive, guarded_ratio, 1.0)

    step = ratio * coeffs["lr_t"] * update
    new_value = var - step
    return var.assign(new_value, use_locking=self._use_locking)
| 171 | |
| 172 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None): |
| 173 | var_device, var_dtype = var.device, var.dtype.base_dtype |
nothing calls this directly
no test coverage detected