Repository: github.com/ddbourgin/numpy-ml

Method grad

numpy_ml/neural_nets/losses/losses.py:296–328

Compute the gradient of the VLB with respect to `y_pred`, `t_mean`, and `t_log_var`, the VAE outputs from which backpropagation continues into the network parameters.
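
The returned arrays can be checked against a worked derivation. Assume, as the gradient expressions imply (they integrate back to this form up to a constant), that the objective is the batch mean of the binary cross-entropy reconstruction term plus the KL divergence from q(t | x) = N(μ, diag(e^ℓ)) to a standard normal prior, i.e. the negative VLB. Here n_ex is the batch size (the `N = y.shape[0]` in the code, not the feature dimension N of the docstring), ŷ is `y_pred`, μ is `t_mean`, and ℓ is `t_log_var`:

```latex
\mathcal{L} = \frac{1}{n_{ex}} \sum_{i=1}^{n_{ex}} \left[
  -\sum_{j=1}^{N} \Big( y_{ij} \log \hat{y}_{ij} + (1 - y_{ij}) \log (1 - \hat{y}_{ij}) \Big)
  - \frac{1}{2} \sum_{k=1}^{T} \Big( 1 + \ell_{ik} - \mu_{ik}^2 - e^{\ell_{ik}} \Big)
\right]

\frac{\partial \mathcal{L}}{\partial \hat{y}_{ij}}
  = \frac{1}{n_{ex}} \left( -\frac{y_{ij}}{\hat{y}_{ij}} + \frac{1 - y_{ij}}{1 - \hat{y}_{ij}} \right)
\qquad
\frac{\partial \mathcal{L}}{\partial \ell_{ik}} = \frac{e^{\ell_{ik}} - 1}{2\, n_{ex}}
\qquad
\frac{\partial \mathcal{L}}{\partial \mu_{ik}} = \frac{\mu_{ik}}{n_{ex}}
```

The first derivative is `dY_pred` after rewriting `-(y - 1) / (N - N * y_pred)` as `(1 - y) / (N * (1 - y_pred))`; the other two are `dLogVar` and `dMean` directly.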

grad(y, y_pred, t_mean, t_log_var)
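
A minimal usage sketch, assuming only NumPy. `vae_grad` is a hypothetical standalone copy of the method body listed on this page (so the sketch runs without importing numpy_ml), and the array sizes are arbitrary. It shows the returned shapes, that the `t_mean`/`t_log_var` gradients vanish when q(t | x) equals the standard normal prior, and that `N = y.shape[0]` is the batch size, so gradients are averaged over examples.

```python
import numpy as np


def vae_grad(y, y_pred, t_mean, t_log_var):
    """Hypothetical standalone copy of the `grad` body listed on this page."""
    N = y.shape[0]  # batch size n_ex, not the feature dimension
    eps = np.finfo(float).eps
    y_pred = np.clip(y_pred, eps, 1 - eps)
    dY_pred = -y / (N * y_pred) - (y - 1) / (N - N * y_pred)
    dLogVar = (np.exp(t_log_var) - 1) / (2 * N)
    dMean = t_mean / N
    return dY_pred, dLogVar, dMean


rng = np.random.default_rng(0)
n_ex, n_feats, latent_dim = 4, 784, 2  # e.g. four flattened 28x28 images

y = (rng.random((n_ex, n_feats)) > 0.5).astype(float)  # binary targets
y_pred = rng.random((n_ex, n_feats))                    # reconstructions in [0, 1)
t_mean = rng.standard_normal((n_ex, latent_dim))
t_log_var = rng.standard_normal((n_ex, latent_dim))

dY_pred, dLogVar, dMean = vae_grad(y, y_pred, t_mean, t_log_var)
print(dY_pred.shape, dLogVar.shape, dMean.shape)  # (4, 784) (4, 2) (4, 2)

# The latent gradients vanish when q(t | x) is the N(0, I) prior (mean 0, log-variance 0).
zeros = np.zeros_like(t_mean)
_, dLogVar0, dMean0 = vae_grad(y, y_pred, zeros, zeros)
print(np.allclose(dLogVar0, 0), np.allclose(dMean0, 0))  # True True

# Because N is the batch size, duplicating every example halves each row's gradient.
dY2, _, _ = vae_grad(
    np.vstack([y, y]), np.vstack([y_pred, y_pred]),
    np.vstack([t_mean, t_mean]), np.vstack([t_log_var, t_log_var]),
)
print(np.allclose(dY2[:n_ex], dY_pred / 2))  # True
```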


    @staticmethod
    def grad(y, y_pred, t_mean, t_log_var):
        """
        Compute the gradient of the VLB with regard to the network parameters.

        Parameters
        ----------
        y : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, N)`
            The original images.
        y_pred : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, N)`
            The VAE reconstruction of the images.
        t_mean: :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, T)`
            Mean of the variational distribution :math:`q(t | x)`.
        t_log_var: :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, T)`
            Log of the variance vector of the variational distribution
            :math:`q(t | x)`.

        Returns
        -------
        dY_pred : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, N)`
            The gradient of the VLB with regard to `y_pred`.
        dLogVar : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, T)`
            The gradient of the VLB with regard to `t_log_var`.
        dMean : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, T)`
            The gradient of the VLB with regard to `t_mean`.
        """
        # Batch size n_ex (not the feature dimension N used in the shapes above)
        N = y.shape[0]
        # Clip predictions away from 0 and 1 so the divisions below stay finite
        eps = np.finfo(float).eps
        y_pred = np.clip(y_pred, eps, 1 - eps)

        # Derivative of the batch-mean binary cross-entropy with respect to y_pred
        dY_pred = -y / (N * y_pred) - (y - 1) / (N - N * y_pred)
        # Derivatives of the batch-mean KL(q(t | x) || N(0, I)) term
        dLogVar = (np.exp(t_log_var) - 1) / (2 * N)
        dMean = t_mean / N
        return dY_pred, dLogVar, dMean
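
One way to sanity-check the listing is a central-difference gradient check. The sketch assumes the objective is the batch-mean binary cross-entropy plus the KL term to a standard normal prior (the form the gradient expressions integrate back to). `vae_loss`, `vae_grad`, and `numerical_grad` are hypothetical helper names; `vae_grad` is a standalone copy of the method body so the check runs without numpy_ml installed.

```python
import numpy as np


def vae_loss(y, y_pred, t_mean, t_log_var):
    """Assumed objective: batch mean of BCE plus KL(q(t | x) || N(0, I))."""
    rec = -np.sum(y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred), axis=1)
    kl = -0.5 * np.sum(1 + t_log_var - t_mean ** 2 - np.exp(t_log_var), axis=1)
    return np.mean(rec + kl)


def vae_grad(y, y_pred, t_mean, t_log_var):
    """Standalone copy of the `grad` body listed on this page."""
    N = y.shape[0]
    eps = np.finfo(float).eps
    y_pred = np.clip(y_pred, eps, 1 - eps)
    dY_pred = -y / (N * y_pred) - (y - 1) / (N - N * y_pred)
    dLogVar = (np.exp(t_log_var) - 1) / (2 * N)
    dMean = t_mean / N
    return dY_pred, dLogVar, dMean


def numerical_grad(f, x, h=1e-6):
    """Central-difference estimate of df/dx, perturbing x in place one entry at a time."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        orig = x[idx]
        x[idx] = orig + h
        f_plus = f()
        x[idx] = orig - h
        f_minus = f()
        x[idx] = orig  # restore the entry before moving on
        g[idx] = (f_plus - f_minus) / (2 * h)
    return g


rng = np.random.default_rng(1)
n_ex, n_feats, latent_dim = 3, 5, 2
y = (rng.random((n_ex, n_feats)) > 0.5).astype(float)
y_pred = rng.uniform(0.1, 0.9, size=(n_ex, n_feats))  # away from 0 and 1, so clipping is a no-op
t_mean = rng.standard_normal((n_ex, latent_dim))
t_log_var = rng.standard_normal((n_ex, latent_dim))


def current_loss():
    # Reads the arrays by name, so in-place perturbations are picked up
    return vae_loss(y, y_pred, t_mean, t_log_var)


dY_pred, dLogVar, dMean = vae_grad(y, y_pred, t_mean, t_log_var)
errors = {
    "y_pred": np.max(np.abs(dY_pred - numerical_grad(current_loss, y_pred))),
    "t_log_var": np.max(np.abs(dLogVar - numerical_grad(current_loss, t_log_var))),
    "t_mean": np.max(np.abs(dMean - numerical_grad(current_loss, t_mean))),
}
print(errors)  # every entry should be far below 1e-6
```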

Callers (15)

test_VAE_loss (function) · 0.95
backward (method) · 0.45
_bwd (method) · 0.45
_bwd (method) · 0.45
_bwd (method) · 0.45
_bwd2 (method) · 0.45
_bwd (method) · 0.45
_bwd2 (method) · 0.45
_bwd (method) · 0.45
_backward_naive (method) · 0.45
_bwd (method) · 0.45
_backward_naive (method) · 0.45

Calls

no outgoing calls

Tested by (1)

test_VAE_loss (function) · 0.76