MCPcopy
hub / github.com/dmlc/dgl / process

Method process

python/dgl/data/adapter.py:90–147  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

88 )
89
def process(self):
    """Prepare ``self.g`` for node prediction and set ``self.num_classes``.

    The first item of ``self.dataset`` is cloned into ``self.g``. A dataset
    that exposes ``get_idx_split`` is treated as OGB-style: its first item
    is unpacked as a ``(graph, label)`` pair and the label is reshaped to
    one entry per node and stored under ``"label"``.

    Split handling depends on ``self.split_ratio``:

    * ``None`` and OGB-style dataset: ``train_mask`` / ``val_mask`` /
      ``test_mask`` are built from the ``"train"`` / ``"valid"`` /
      ``"test"`` entries of ``get_idx_split()``.
    * ``None`` otherwise: the three masks must already be stored on the
      target node type.
    * not ``None``: masks are generated by ``utils.add_nodepred_split``.

    Afterwards ``self._set_split_index()`` is called and
    ``self.num_classes`` is taken from the dataset's ``num_classes``
    attribute, falling back to the number of unique label values.

    Raises
    ------
    ValueError
        If no ``"label"`` entry exists on the target node type.
    AssertionError
        If ``split_ratio`` is ``None``, the dataset is not OGB-style and
        one of the three masks is missing.
    """
    is_ogb = hasattr(self.dataset, "get_idx_split")
    if is_ogb:
        g, label = self.dataset[0]
        self.g = g.clone()
        # NOTE(review): this branch writes through ``ndata`` rather than
        # ``nodes[self.target_ntype].data`` -- presumably it assumes a
        # graph with a single node type; confirm.
        self.g.ndata["label"] = F.reshape(label, (g.num_nodes(),))
    else:
        self.g = self.dataset[0].clone()

    if "label" not in self.g.nodes[self.target_ntype].data:
        raise ValueError(
            "Missing node labels. Make sure labels are stored "
            "under name 'label'."
        )

    if self.split_ratio is None:
        if is_ogb:
            split = self.dataset.get_idx_split()
            train_idx, val_idx, test_idx = (
                split["train"],
                split["valid"],
                split["test"],
            )
            n = self.g.num_nodes()
            train_mask = utils.generate_mask_tensor(
                utils.idx2mask(train_idx, n)
            )
            val_mask = utils.generate_mask_tensor(
                utils.idx2mask(val_idx, n)
            )
            test_mask = utils.generate_mask_tensor(
                utils.idx2mask(test_idx, n)
            )
            self.g.ndata["train_mask"] = train_mask
            self.g.ndata["val_mask"] = val_mask
            self.g.ndata["test_mask"] = test_mask
        else:
            node_data = self.g.nodes[self.target_ntype].data
            for mask_name in ("train_mask", "val_mask", "test_mask"):
                if mask_name not in node_data:
                    # Raised explicitly instead of via ``assert`` so the
                    # check is not stripped when Python runs with ``-O``;
                    # the exception type and message are unchanged.
                    raise AssertionError(
                        mask_name
                        + " is not provided, please specify split_ratio"
                        " to generate the masks"
                    )
    else:
        if self.verbose:
            print("Generating train/val/test masks...")
        utils.add_nodepred_split(self, self.split_ratio, self.target_ntype)

    self._set_split_index()

    self.num_classes = getattr(self.dataset, "num_classes", None)
    if self.num_classes is None:
        # The dataset does not advertise a class count: infer it from the
        # distinct label values on the target node type.
        self.num_classes = len(
            F.unique(self.g.nodes[self.target_ntype].data["label"])
        )

Callers

nothing calls this directly

Calls 4

_set_split_indexMethod · 0.95
cloneMethod · 0.45
num_nodesMethod · 0.45
get_idx_splitMethod · 0.45

Tested by

no test coverage detected