(self)
class TestWhile(test_util.TestCase):
    """Tests for WhileNet / IfNet control flow built with NetBuilder."""

    @unittest.skip("Skip flaky test.")
    def testWhile(self):
        """Check forward values and gradients through a WhileNet with a nested IfNet.

        The loop condition first increments ``i`` (starting at 0) and then
        tests ``i <= 2``, so the body executes for i == 1 and i == 2:

        * x is squared on every iteration        -> x**4 = 16, d/dx = 4*x**3 = 32
        * y is squared only while i < 2 (i == 1) -> y**2 = 9,  d/dy = 2*y    = 6
        * z is cubed only otherwise (i == 2)     -> z**3 = 8,  d/dz = 3*z**2 = 12

        Gradients are taken of ``s = x + y + z`` and are evaluated at the
        initial values x = 2.0, y = 3.0, z = 2.0.
        """
        with NetBuilder(_use_control_ops=True) as nb:
            ops.Copy(ops.Const(0), "i")
            ops.Copy(ops.Const(1), "one")
            ops.Copy(ops.Const(2), "two")
            ops.Copy(ops.Const(2.0), "x")
            ops.Copy(ops.Const(3.0), "y")
            ops.Copy(ops.Const(2.0), "z")
            # raises x to the power of 4 and y to the power of 2
            # and z to the power of 3
            with ops.WhileNet():
                with ops.Condition():
                    # Increment before testing, so the body sees i == 1, 2.
                    ops.Add(["i", "one"], "i")
                    ops.LE(["i", "two"])
                ops.Pow("x", "x", exponent=2.0)
                with ops.IfNet(ops.LT(["i", "two"])):
                    ops.Pow("y", "y", exponent=2.0)
                with ops.Else():
                    ops.Pow("z", "z", exponent=3.0)

            ops.Add(["x", "y"], "x_plus_y")
            ops.Add(["x_plus_y", "z"], "s")

        nets = nb.get()
        # Use a unittest assertion rather than a bare ``assert``: bare asserts
        # are stripped under ``python -O``, which would silently skip this check.
        self.assertEqual(len(nets), 1, "Expected a single net produced")
        net = nets[0]

        net.AddGradientOperators(["s"])
        workspace.RunNetOnce(net)
        # (x^4)' = 4x^3
        self.assertAlmostEqual(workspace.FetchBlob("x_grad"), 32)
        self.assertAlmostEqual(workspace.FetchBlob("x"), 16)
        # (y^2)' = 2y
        self.assertAlmostEqual(workspace.FetchBlob("y_grad"), 6)
        self.assertAlmostEqual(workspace.FetchBlob("y"), 9)
        # (z^3)' = 3z^2
        self.assertAlmostEqual(workspace.FetchBlob("z_grad"), 12)
        self.assertAlmostEqual(workspace.FetchBlob("z"), 8)
| 553 | |
| 554 | |
| 555 | if __name__ == '__main__': |
nothing calls this directly
no test coverage detected