(a, b)
| 229 | |
| 230 | # # correctness check |
| 231 | def _print_error(a, b): |
| 232 | for i, (x, y) in enumerate( |
| 233 | zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten()) |
| 234 | ): |
| 235 | if not np.allclose(x, y): |
| 236 | print("@{} {} v.s. {}".format(i, x, y)) |
| 237 | |
| 238 | assert F.allclose(r1, r3) |
| 239 | assert F.allclose(r2, r4) |