()
| 13 | |
| 14 | |
def test_keepdims_wrapper_no_axis():
    """Exercise ``keepdims_wrapper`` on a reduction over all axes.

    The unwrapped reducer (called with ``axis=None``) returns a 0-d result.
    The wrapped reducer must agree with it when ``keepdims=False`` and must
    return a (1, 1, 1, 1) array holding the same total when ``keepdims=True``.
    """

    def total(a, axis=None):
        return a.sum(axis=axis)

    wrapped_total = keepdims_wrapper(total)

    # The wrapper is expected to hand back a different callable than the one
    # it was given.
    assert wrapped_total != total

    arr = np.arange(24).reshape(1, 2, 3, 4)
    expected_total = 276  # 0 + 1 + ... + 23

    plain = total(arr)
    kept = wrapped_total(arr, keepdims=True)
    dropped = wrapped_total(arr, keepdims=False)

    def assert_collapsed(result):
        # A full reduction without keepdims leaves no dimensions behind.
        assert result.ndim == 0
        assert result.shape == ()
        assert result == expected_total

    assert_collapsed(plain)

    # keepdims=True keeps every one of the four axes, each with length 1.
    assert kept.ndim == 4
    assert kept.shape == (1, 1, 1, 1)
    assert (kept == expected_total).all()

    assert_collapsed(dropped)
| 41 | |
| 42 | def test_keepdims_wrapper_one_axis(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…