| 22 | ) |
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_add_nodepred_split():
    """Check that ``add_nodepred_split`` attaches node split masks.

    Exercises two call styles of ``data.utils.add_nodepred_split``:

    * without ``ntype`` on ``AmazonCoBuyComputerDataset``, where the masks
      are read back through ``g.ndata``;
    * with ``ntype="Publikationen"`` on ``AIFBDataset``, where the masks are
      read back through ``g.nodes["Publikationen"].data``.

    In both cases the train/val/test masks must be present on the first
    graph of the dataset after the call.
    """
    # Names of the masks add_nodepred_split is documented to create.
    mask_names = ("train_mask", "val_mask", "test_mask")

    # Split without a node type: masks are accessed via ndata.
    dataset = data.AmazonCoBuyComputerDataset()
    data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
    for name in mask_names:
        assert name in dataset[0].ndata

    # Split restricted to one node type of the AIFB graph.
    dataset = data.AIFBDataset()
    data.utils.add_nodepred_split(
        dataset, [0.8, 0.1, 0.1], ntype="Publikationen"
    )
    for name in mask_names:
        assert name in dataset[0].nodes["Publikationen"].data
| 37 | |
| 38 | @unittest.skipIf( |