(a, b)
| 105 | |
| 106 | # correctness check |
def _print_error(a, b, *, rtol=1e-05, atol=1e-08):
    """Print every element where tensors ``a`` and ``b`` disagree.

    Debugging helper for a failed correctness check. Both inputs are
    converted with ``F.asnumpy`` and flattened, then compared element-wise.
    Each mismatching element is printed as
    ``@<flat index> <value in a> v.s. <value in b>``.

    Args:
        a: First tensor (anything ``F.asnumpy`` accepts).
        b: Second tensor (anything ``F.asnumpy`` accepts).
        rtol: Relative tolerance passed to ``np.isclose``. The default matches
            ``np.allclose``, which the original per-element check used.
        atol: Absolute tolerance passed to ``np.isclose``. The default matches
            ``np.allclose``. Pass the tolerances of the failing ``allclose``
            check to avoid reporting elements that check considered equal.
    """
    a_flat = F.asnumpy(a).flatten()
    b_flat = F.asnumpy(b).flatten()
    # Plain ``zip`` silently drops the tail of the longer array; compare the
    # overlapping prefix but report a size mismatch explicitly.
    n = min(a_flat.size, b_flat.size)
    if a_flat.size != b_flat.size:
        print(
            "size mismatch: {} v.s. {} elements".format(a_flat.size, b_flat.size)
        )
    # One vectorized comparison instead of a Python-level np.allclose call per
    # element. equal_nan stays False (the default), so NaNs are still reported.
    mismatch = ~np.isclose(a_flat[:n], b_flat[:n], rtol=rtol, atol=atol)
    for i in np.flatnonzero(mismatch):
        print("@{} {} v.s. {}".format(i, a_flat[i], b_flat[i]))
| 113 | |
| 114 | if not F.allclose(r1, r2): |
| 115 | _print_error(r1, r2) |