(non_blocking)
@unittest.skipIf(F._default_context_str != "gpu", "CopyTo needs GPU to test")
@pytest.mark.parametrize("non_blocking", [False, True])
def test_CopyTo(non_blocking):
    """Verify that CopyTo yields minibatches whose seeds live on a CUDA device.

    The datapipe is built three ways -- the class constructor with only a
    device, the class constructor with an explicit ``non_blocking`` flag, and
    the functional ``copy_to`` form -- and every yielded minibatch is checked.

    Parameters
    ----------
    non_blocking : bool
        Forwarded to ``CopyTo`` / ``copy_to``. When true, each item produced
        by the sampler is additionally passed through ``pin_memory()`` before
        being copied.
    """
    seed_items = gb.ItemSet(torch.arange(20), names="seeds")
    item_sampler = gb.ItemSampler(seed_items, 4)
    if non_blocking:
        # NOTE(review): presumably pinned host memory is what lets the
        # non-blocking copy run asynchronously -- confirm against CopyTo docs.
        item_sampler = item_sampler.transform(
            lambda minibatch: minibatch.pin_memory()
        )

    def _assert_all_on_cuda(datapipe):
        # Drain the datapipe and check the device of every minibatch's seeds.
        for minibatch in datapipe:
            assert minibatch.seeds.device.type == "cuda"

    # Invoke CopyTo via class constructor, relying on the default flag.
    _assert_all_on_cuda(gb.CopyTo(item_sampler, "cuda"))

    # Invoke CopyTo via class constructor with the parametrized flag.
    _assert_all_on_cuda(gb.CopyTo(item_sampler, "cuda", non_blocking))

    # Invoke CopyTo via functional form.
    _assert_all_on_cuda(item_sampler.copy_to("cuda", non_blocking))
| 45 | |
| 46 | |
| 47 | @pytest.mark.parametrize( |
nothing calls this directly
no test coverage detected