hub / github.com/dmlc/dgl / AsNodePredDataset

Class AsNodePredDataset

python/dgl/data/adapter.py:18–204

Repurpose a dataset for a standard semi-supervised transductive node prediction task.


class AsNodePredDataset(DGLDataset):
    """Repurpose a dataset for a standard semi-supervised transductive
    node prediction task.

    The class converts a given dataset into a new dataset object such that:

      - Contains only one graph, accessible from ``dataset[0]``.
      - The graph stores:

        - Node labels in ``g.ndata['label']``.
        - Train/val/test masks in ``g.ndata['train_mask']``, ``g.ndata['val_mask']``,
          and ``g.ndata['test_mask']`` respectively.
      - In addition, the dataset contains the following attributes:

        - ``num_classes``, the number of classes to predict.
        - ``train_idx``, ``val_idx``, ``test_idx``, train/val/test indexes.

    If the input dataset contains heterogeneous graphs, users need to specify the
    ``target_ntype`` argument to indicate which node type to make predictions for.
    In this case:

      - Node labels are stored in ``g.nodes[target_ntype].data['label']``.
      - Training masks are stored in ``g.nodes[target_ntype].data['train_mask']``.
        So do validation and test masks.

    The class will keep only the first graph in the provided dataset and
    generate train/val/test masks according to the given split ratio. The generated
    masks will be cached to disk for fast re-loading. If the provided split ratio
    differs from the cached one, it will re-process the dataset properly.

    Parameters
    ----------
    dataset : DGLDataset
        The dataset to be converted.
    split_ratio : (float, float, float), optional
        Split ratios for training, validation and test sets. They must sum to one.
    target_ntype : str, optional
        The node type to add split mask for.

    Attributes
    ----------
    num_classes : int
        Number of classes to predict.
    train_idx : Tensor
        A 1-D integer tensor of training node IDs.
    val_idx : Tensor
        A 1-D integer tensor of validation node IDs.
    test_idx : Tensor
        A 1-D integer tensor of test node IDs.

    Examples
    --------
    >>> ds = dgl.data.AmazonCoBuyComputerDataset()
    >>> print(ds)
    Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...)
    >>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
    >>> print(new_ds)
    Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...)
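
The docstring says the masks are generated from the split ratio and that the dataset exposes train/val/test node indexes alongside the masks. The sketch below illustrates how the two representations can relate, using only the Python standard library. It is not DGL's implementation: the helper names split_masks and mask_to_idx, the seed parameter, and the choice to give the rounding remainder to the test set are all assumptions made for illustration.

```python
import random


def split_masks(num_nodes, split_ratio, seed=0):
    """Hypothetical helper (not a DGL API): build boolean train/val/test
    masks over `num_nodes` nodes from a (train, val, test) ratio triple."""
    train_r, val_r, test_r = split_ratio
    if abs(train_r + val_r + test_r - 1.0) > 1e-6:
        raise ValueError("split ratios must sum to one")

    # Shuffle node IDs deterministically so the split is reproducible.
    order = list(range(num_nodes))
    random.Random(seed).shuffle(order)

    n_train = int(num_nodes * train_r)
    n_val = int(num_nodes * val_r)
    # Assumption: whatever is left after train and val goes to test.
    parts = {
        "train_mask": order[:n_train],
        "val_mask": order[n_train:n_train + n_val],
        "test_mask": order[n_train + n_val:],
    }

    masks = {}
    for name, node_ids in parts.items():
        mask = [False] * num_nodes
        for node_id in node_ids:
            mask[node_id] = True
        masks[name] = mask
    return masks


def mask_to_idx(mask):
    """Hypothetical helper: node IDs whose mask entry is True."""
    return [node_id for node_id, flag in enumerate(mask) if flag]


masks = split_masks(10, (0.8, 0.1, 0.1))
train_idx = mask_to_idx(masks["train_mask"])
val_idx = mask_to_idx(masks["val_mask"])
test_idx = mask_to_idx(masks["test_mask"])
```

In this sketch every node falls in exactly one of the three sets, so the index lists are disjoint and together cover all node IDs. Whether DGL shuffles, rounds, or assigns the remainder the same way is not stated in the docstring above.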

Calls

no outgoing calls

Tested by

no test coverage detected