```python
class AsNodePredDataset(DGLDataset):
    """Repurpose a dataset for a standard semi-supervised transductive
    node prediction task.

    The class converts a given dataset into a new dataset object such that:

    - It contains only one graph, accessible from ``dataset[0]``.
    - The graph stores:

      - Node labels in ``g.ndata['label']``.
      - Train/val/test masks in ``g.ndata['train_mask']``, ``g.ndata['val_mask']``,
        and ``g.ndata['test_mask']``, respectively.

    - In addition, the dataset has the following attributes:

      - ``num_classes``, the number of classes to predict.
      - ``train_idx``, ``val_idx``, and ``test_idx``, the train/val/test node indices.

    If the input dataset contains heterogeneous graphs, users must specify the
    ``target_ntype`` argument to indicate which node type to make predictions for.
    In this case:

    - Node labels are stored in ``g.nodes[target_ntype].data['label']``.
    - Train/val/test masks are stored in
      ``g.nodes[target_ntype].data['train_mask']``,
      ``g.nodes[target_ntype].data['val_mask']``, and
      ``g.nodes[target_ntype].data['test_mask']``, respectively.

    The class keeps only the first graph in the provided dataset and
    generates train/val/test masks according to the given split ratio. The
    generated masks are cached to disk for fast reloading. If the provided
    split ratio differs from the cached one, the dataset is re-processed.

    Parameters
    ----------
    dataset : DGLDataset
        The dataset to be converted.
    split_ratio : (float, float, float), optional
        Split ratios for the training, validation, and test sets. They must
        sum to one.
    target_ntype : str, optional
        The node type to add split masks for. Required if the input graph
        is heterogeneous.

    Attributes
    ----------
    num_classes : int
        Number of classes to predict.
    train_idx : Tensor
        A 1-D integer tensor of training node IDs.
    val_idx : Tensor
        A 1-D integer tensor of validation node IDs.
    test_idx : Tensor
        A 1-D integer tensor of test node IDs.

    Examples
    --------
    >>> ds = dgl.data.AmazonCoBuyComputerDataset()
    >>> print(ds)
    Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...)
    >>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
    >>> print(new_ds)
    Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...)
```
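The docstring states that masks are generated from `split_ratio` and that `train_idx`, `val_idx`, and `test_idx` list the same nodes the masks select. Below is a minimal sketch of one way to produce both, assuming a uniform random shuffle with a fixed seed. `make_split_masks` is a hypothetical helper, not DGL's implementation, and plain Python lists stand in for tensors.

```python
import random


def make_split_masks(num_nodes, split_ratio, seed=0):
    """Randomly partition node IDs 0..num_nodes-1 into train/val/test sets.

    Hypothetical helper for illustration; not DGL's implementation.
    Returns (masks, idx): masks maps 'train_mask'/'val_mask'/'test_mask' to
    lists of bools, idx maps 'train_idx'/'val_idx'/'test_idx' to sorted IDs.
    """
    if len(split_ratio) != 3 or abs(sum(split_ratio) - 1.0) > 1e-6:
        raise ValueError("split_ratio must contain three values that sum to one")
    perm = list(range(num_nodes))
    random.Random(seed).shuffle(perm)  # seeded shuffle -> reproducible split
    n_train = int(split_ratio[0] * num_nodes)
    n_val = int(split_ratio[1] * num_nodes)
    # The test set takes whatever remains, so every node lands in exactly one split.
    parts = {
        "train": perm[:n_train],
        "val": perm[n_train:n_train + n_val],
        "test": perm[n_train + n_val:],
    }
    masks, idx = {}, {}
    for name, ids in parts.items():
        mask = [False] * num_nodes
        for i in ids:
            mask[i] = True
        masks[f"{name}_mask"] = mask
        idx[f"{name}_idx"] = sorted(ids)
    return masks, idx


masks, idx = make_split_masks(10, (0.8, 0.1, 0.1))
print(len(idx["train_idx"]), len(idx["val_idx"]), len(idx["test_idx"]))  # 8 1 1
```

Assigning the remainder to the test set keeps the three splits a true partition of the nodes even when the ratios do not multiply out to whole numbers.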
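For heterogeneous graphs, the docstring places labels and masks under `g.nodes[target_ntype].data` rather than `g.ndata`. The toy sketch below mimics that per-node-type storage with a plain dict of dicts; `attach_split` and the `"paper"`/`"author"` node types are hypothetical stand-ins for illustration, not DGL objects or calls.

```python
def attach_split(node_data, target_ntype, labels, masks):
    """Attach labels and split masks to one node type of a toy heterograph.

    node_data is a plain dict of dicts mimicking g.nodes[ntype].data;
    it is a hypothetical stand-in, not a DGL graph.
    """
    if target_ntype not in node_data:
        raise KeyError(f"unknown node type: {target_ntype!r}")
    data = node_data[target_ntype]
    data["label"] = labels
    data.update(masks)  # adds train_mask, val_mask, test_mask
    return node_data


toy = {"paper": {}, "author": {}}
attach_split(
    toy,
    "paper",
    [0, 1, 1],
    {
        "train_mask": [True, False, False],
        "val_mask": [False, True, False],
        "test_mask": [False, False, True],
    },
)
print(sorted(toy["paper"]))  # ['label', 'test_mask', 'train_mask', 'val_mask']
print(toy["author"])  # {}
```

In this toy, only the chosen node type receives the label and mask fields; the other node type is left as it was.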
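The docstring says generated masks are cached to disk and the dataset is re-processed when the requested split ratio differs from the cached one. A minimal sketch of that cache-invalidation check follows, assuming a JSON file that records the ratio alongside the split. `load_or_build_split`, `build_fn`, and the file layout are hypothetical, not DGL's on-disk format.

```python
import json
import os
import tempfile


def load_or_build_split(cache_path, split_ratio, build_fn):
    """Reuse a cached split only if it was built with the same split ratio.

    Hypothetical helper: cache_path, build_fn, and the JSON layout are
    illustrative assumptions, not DGL's cache format.
    Returns (split, rebuilt) where rebuilt is True if build_fn was called.
    """
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cached = json.load(f)
        if cached["split_ratio"] == list(split_ratio):
            return cached["split"], False  # cache hit: no re-processing
    split = build_fn(split_ratio)  # no cache, or ratio changed: rebuild
    with open(cache_path, "w") as f:
        json.dump({"split_ratio": list(split_ratio), "split": split}, f)
    return split, True


calls = []


def build(ratio):
    calls.append(ratio)
    return {"train_idx": [0, 1], "val_idx": [2], "test_idx": [3]}


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "split.json")
    _, rebuilt1 = load_or_build_split(path, (0.5, 0.25, 0.25), build)
    _, rebuilt2 = load_or_build_split(path, (0.5, 0.25, 0.25), build)
    _, rebuilt3 = load_or_build_split(path, (0.6, 0.2, 0.2), build)
print(rebuilt1, rebuilt2, rebuilt3)  # True False True
```

Storing the ratio next to the split makes the check a single equality test; JSON round-trips Python floats exactly, so identical ratios compare equal after reloading.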