
    @staticmethod
    def grad(y, y_pred, t_mean, t_log_var):
        """
        Compute the gradient of the VLB with respect to the network outputs
        `y_pred`, `t_mean`, and `t_log_var`.

        Parameters
        ----------
        y : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, N)`
            The original images.
        y_pred : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, N)`
            The VAE reconstruction of the images.
        t_mean : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, T)`
            Mean of the variational distribution :math:`q(t | x)`.
        t_log_var : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, T)`
            Log of the variance vector of the variational distribution
            :math:`q(t | x)`.

        Returns
        -------
        dY_pred : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, N)`
            The gradient of the VLB with respect to `y_pred`.
        dLogVar : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, T)`
            The gradient of the VLB with respect to `t_log_var`.
        dMean : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, T)`
            The gradient of the VLB with respect to `t_mean`.
        """
        # Number of examples in the batch (the loss is averaged over it).
        # Not to be confused with the per-image dimension `N` in the shapes above.
        n_ex = y.shape[0]

        # Clip predictions away from 0 and 1 so the divisions below stay finite
        eps = np.finfo(float).eps
        y_pred = np.clip(y_pred, eps, 1 - eps)

        # Reconstruction term: derivative of the batch-mean binary cross-entropy
        dY_pred = -y / (n_ex * y_pred) - (y - 1) / (n_ex - n_ex * y_pred)

        # KL term against a unit Gaussian prior, averaged over the batch
        dLogVar = (np.exp(t_log_var) - 1) / (2 * n_ex)
        dMean = t_mean / n_ex
        return dY_pred, dLogVar, dMean


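# Sketch: finite-difference check of the VAE gradient (illustrative only).
#
# The analytic gradients returned by `grad` can be sanity-checked against
# central finite differences of the loss. Everything below is a hypothetical
# helper, not part of the original module. `_vae_vlb_loss_sketch` ASSUMES the
# loss is the batch mean of the binary cross-entropy reconstruction term plus
# the KL divergence from q(t | x) to a unit Gaussian prior; that is the loss
# whose derivatives agree with the expressions in `grad`.
# `_vae_grad_sketch` repeats those expressions so the check runs standalone.
import numpy as np  # repeated so this sketch runs on its own


def _vae_vlb_loss_sketch(y, y_pred, t_mean, t_log_var):
    """Hypothetical reference loss: mean over examples of BCE + KL."""
    eps = np.finfo(float).eps
    y_pred = np.clip(y_pred, eps, 1 - eps)
    rec = -np.sum(y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred), axis=1)
    kl = -0.5 * np.sum(1 + t_log_var - t_mean ** 2 - np.exp(t_log_var), axis=1)
    return np.mean(rec + kl)


def _vae_grad_sketch(y, y_pred, t_mean, t_log_var):
    """Copy of the expressions in `grad`, so the check is self-contained."""
    n_ex = y.shape[0]
    eps = np.finfo(float).eps
    y_pred = np.clip(y_pred, eps, 1 - eps)
    dY_pred = -y / (n_ex * y_pred) - (y - 1) / (n_ex - n_ex * y_pred)
    dLogVar = (np.exp(t_log_var) - 1) / (2 * n_ex)
    dMean = t_mean / n_ex
    return dY_pred, dLogVar, dMean


def _numerical_grad_sketch(f, x, h=1e-6):
    """Central-difference estimate of df/dx; perturbs `x` in place, then restores it."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        orig = x[idx]
        x[idx] = orig + h
        f_plus = f()
        x[idx] = orig - h
        f_minus = f()
        x[idx] = orig
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad


def _check_vae_grad_sketch(n_ex=3, N=5, T=2, seed=0):
    """Return the largest absolute gap between analytic and numerical gradients."""
    rng = np.random.default_rng(seed)
    y = (rng.random((n_ex, N)) > 0.5).astype(float)  # binary "images"
    y_pred = rng.uniform(0.1, 0.9, size=(n_ex, N))   # kept away from the clip bounds
    t_mean = rng.normal(size=(n_ex, T))
    t_log_var = rng.normal(size=(n_ex, T))

    def loss():
        # Closure reads the arrays that _numerical_grad_sketch perturbs in place
        return _vae_vlb_loss_sketch(y, y_pred, t_mean, t_log_var)

    dY_pred, dLogVar, dMean = _vae_grad_sketch(y, y_pred, t_mean, t_log_var)
    return max(
        np.max(np.abs(dY_pred - _numerical_grad_sketch(loss, y_pred))),
        np.max(np.abs(dLogVar - _numerical_grad_sketch(loss, t_log_var))),
        np.max(np.abs(dMean - _numerical_grad_sketch(loss, t_mean))),
    )


# Usage (hypothetical): a gap well below 1e-6 suggests the analytic
# gradients and the assumed loss are consistent.
# _check_vae_grad_sketch()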
class WGAN_GPLoss(ObjectiveBase):