MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_WGAN_GP_loss

Function test_WGAN_GP_loss

numpy_ml/tests/test_nn.py:189–238  ·  view source on GitHub ↗
(N=5)

Source from the content-addressed store, hash-verified

187
188
def test_WGAN_GP_loss(N=5):
    """Check numpy_ml's ``WGAN_GPLoss`` against the ``TorchWGANGPLoss`` reference.

    Each trial draws:

    * an integer ``lambda_`` in ``[0, 10)``;
    * two standardized 1-D arrays ``Y_real`` and ``Y_fake`` of length
      ``n_ex`` (presumably critic scores for real / fake examples);
    * a standardized ``(n_ex, n_feats)`` array ``gradInterp``.

    It then compares the losses and gradients from ``WGAN_GPLoss`` with
    those returned by ``TorchWGANGPLoss(lambda_).extract_grads`` using
    ``np.testing.assert_allclose`` (``rtol=0.1``, ``atol=1e-2``).

    Parameters
    ----------
    N : int or None
        Loop bound. The trial counter starts at 1 and the loop runs while
        it is below ``N``, so ``N - 1`` trials are checked. ``None`` means
        loop forever (``np.inf``).
    """
    from numpy_ml.neural_nets.losses import WGAN_GPLoss

    np.random.seed(12345)

    N = np.inf if N is None else N

    i = 1
    while i < N:
        lambda_ = np.random.randint(0, 10)
        n_ex = np.random.randint(1, 10)
        n_feats = np.random.randint(2, 10)
        Y_real = random_tensor([n_ex], standardize=True)
        Y_fake = random_tensor([n_ex], standardize=True)
        gradInterp = random_tensor([n_ex, n_feats], standardize=True)

        mine = WGAN_GPLoss(lambda_=lambda_)
        # "C" and "G" select two loss modes (presumably critic and generator);
        # only the "C" calls also pass Y_real and gradInterp.
        C_loss = mine(Y_fake, "C", Y_real, gradInterp)
        G_loss = mine(Y_fake, "G")

        C_dY_fake, dY_real, dGradInterp = mine.grad(Y_fake, "C", Y_real, gradInterp)
        G_dY_fake = mine.grad(Y_fake, "G")

        golds = TorchWGANGPLoss(lambda_).extract_grads(Y_real, Y_fake, gradInterp)
        # Discard draws where the reference gradient wrt gradInterp is NaN.
        # The trial counter is not advanced, so fresh inputs are drawn instead.
        if np.isnan(golds["C_dGradInterp"]).any():
            continue

        # The critic gradients wrt Y_real and Y_fake are negated before the
        # comparison -- presumably a sign-convention difference between the
        # two implementations. NOTE(review): confirm against TorchWGANGPLoss.
        params = [
            (Y_real, "Y_real"),
            (Y_fake, "Y_fake"),
            (gradInterp, "gradInterp"),
            (C_loss, "C_loss"),
            (G_loss, "G_loss"),
            (-dY_real, "C_dY_real"),
            (-C_dY_fake, "C_dY_fake"),
            (dGradInterp, "C_dGradInterp"),
            (G_dY_fake, "G_dY_fake"),
        ]

        print("\nTrial {}".format(i))
        # Use a distinct loop name so the `mine` loss object is not shadowed.
        for ix, (value, label) in enumerate(params):
            np.testing.assert_allclose(
                value,
                golds[label],
                err_msg=err_fmt(params, golds, ix),
                rtol=0.1,
                atol=1e-2,
            )
            print("\tPASSED {}".format(label))
        i += 1
239
240
241def test_NCELoss(N=1):

Callers

nothing calls this directly

Calls 6

grad · Method · 0.95
random_tensor · Function · 0.90
WGAN_GPLoss · Class · 0.90
TorchWGANGPLoss · Class · 0.85
err_fmt · Function · 0.70
extract_grads · Method · 0.45

Tested by

no test coverage detected