()
| 34 | |
| 35 | |
def test_whiten():
    """Check ``whiten`` on a single row with one masked-out position.

    The first three entries are inside the mask and must average to
    (approximately) zero after whitening; the last entry is outside the
    mask and must come back as exactly zero.
    """
    values = torch.tensor([[1.0, 2.0, 3.0, 0.0]])
    keep = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)
    whitened = whiten(values, keep)

    # Mean over the masked (kept) positions only; float tolerance for rounding.
    masked_mean = whitened[0, :3].mean().item()
    assert abs(masked_mean) < 1e-5

    # Position excluded by the mask is expected to be zeroed out.
    assert whitened[0, 3] == 0

    print("ok whiten: zero-mean over mask, zero outside")
| 43 | |
| 44 | |
| 45 | def test_ppo_losses(): |