MCPcopy Index your code
hub / github.com/pytorch/pytorch / testIf

Method testIf

caffe2/python/gradient_check_test.py:410–511  ·  view source on GitHub ↗
(self)

Source retrieved from the content-addressed store (hash-verified).

408
409class TestIf(test_util.TestCase):
410 def testIf(self):
411 W_a_values = [2.0, 1.5]
412 B_a_values = [0.5]
413 W_b_values = [7.0, 3.5]
414 B_b_values = [1.5]
415
416 with NetBuilder(_use_control_ops=True) as init_nb:
417 W_a = ops.UniformFill([], "W_a", shape=[1, 2], min=-1., max=1.)
418 B_a = ops.ConstantFill([], "B_a", shape=[1], value=0.0)
419 W_b = ops.UniformFill([], "W_b", shape=[1, 2], min=-1., max=1.)
420 B_b = ops.ConstantFill([], "B_b", shape=[1], value=0.0)
421
422 W_gt_a = ops.GivenTensorFill(
423 [], "W_gt_a", shape=[1, 2], values=W_a_values)
424 B_gt_a = ops.GivenTensorFill([], "B_gt_a", shape=[1], values=B_a_values)
425 W_gt_b = ops.GivenTensorFill(
426 [], "W_gt_b", shape=[1, 2], values=W_b_values)
427 B_gt_b = ops.GivenTensorFill([], "B_gt_b", shape=[1], values=B_b_values)
428
429 params = [W_gt_a, B_gt_a, W_a, B_a, W_gt_b, B_gt_b, W_b, B_b]
430
431 with NetBuilder(_use_control_ops=True, initial_scope=params) as train_nb:
432 Y_pred = ops.ConstantFill([], "Y_pred", shape=[1], value=0.0)
433 Y_noise = ops.ConstantFill([], "Y_noise", shape=[1], value=0.0)
434
435 switch = ops.UniformFill(
436 [], "switch", shape=[1], min=-1., max=1., run_once=0)
437 zero = ops.ConstantFill([], "zero", shape=[1], value=0.0)
438 X = ops.GaussianFill(
439 [], "X", shape=[4096, 2], mean=0.0, std=1.0, run_once=0)
440 noise = ops.GaussianFill(
441 [], "noise", shape=[4096, 1], mean=0.0, std=1.0, run_once=0)
442
443 with ops.IfNet(ops.LT([switch, zero])):
444 Y_gt = ops.FC([X, W_gt_a, B_gt_a], "Y_gt")
445 ops.Add([Y_gt, noise], Y_noise)
446 ops.FC([X, W_a, B_a], Y_pred)
447 with ops.Else():
448 Y_gt = ops.FC([X, W_gt_b, B_gt_b], "Y_gt")
449 ops.Add([Y_gt, noise], Y_noise)
450 ops.FC([X, W_b, B_b], Y_pred)
451
452 dist = ops.SquaredL2Distance([Y_noise, Y_pred], "dist")
453 loss = dist.AveragedLoss([], ["loss"])
454
455 assert len(init_nb.get()) == 1, "Expected a single init net produced"
456 assert len(train_nb.get()) == 1, "Expected a single train net produced"
457
458 train_net = train_nb.get()[0]
459 gradient_map = train_net.AddGradientOperators([loss])
460
461 init_net = init_nb.get()[0]
462 ITER = init_net.ConstantFill(
463 [], "ITER", shape=[1], value=0, dtype=core.DataType.INT64)
464 train_net.Iter(ITER, ITER)
465 LR = train_net.LearningRate(ITER, "LR", base_lr=-0.1,
466 policy="step", stepsize=20, gamma=0.9)
467 ONE = init_net.ConstantFill([], "ONE", shape=[1], value=1.)

Callers

nothing calls this directly

Calls 13

NetBuilder (class) · 0.90
IfNet (method) · 0.80
FC (method) · 0.80
Iter (method) · 0.80
range (function) · 0.50
abs (function) · 0.50
Add (method) · 0.45
Else (method) · 0.45
get (method) · 0.45
AddGradientOperators (method) · 0.45
Proto (method) · 0.45
items (method) · 0.45

Tested by

no test coverage detected