MCPcopy
hub / github.com/dmlc/dgl / _random_split

Method _random_split

python/dgl/data/fraud.py:209–238  ·  view source on GitHub ↗

Split the dataset into a training set, a validation set, and a testing set.

(self, x, seed=717, train_size=0.7, val_size=0.1)

Source from the content-addressed store, hash-verified

207 return os.path.exists(graph_path)
208
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
    """Randomly split the dataset into training, validation and testing sets.

    Builds three boolean masks of length ``x.shape[0]`` and stores them in
    ``self.graph.ndata`` under the keys ``"train_mask"``, ``"val_mask"`` and
    ``"test_mask"``. Nothing is returned.

    Parameters
    ----------
    x
        Object whose ``shape[0]`` gives the number of entries ``N`` to split.
        Only ``x.shape[0]`` is read.
    seed
        Seed passed to ``np.random.RandomState`` for the permutation, so the
        same seed yields the same split.
    train_size
        Fraction of the candidate indices assigned to the training set
        (taken from the front of the shuffled indices).
    val_size
        Fraction of the candidate indices assigned to the validation set
        (taken from the back of the shuffled indices). Whatever lies between
        the training and validation slices becomes the testing set.

    Raises
    ------
    AssertionError
        If ``train_size + val_size`` is outside ``[0, 1]``, or if either
        fraction on its own is outside ``[0, 1]``.
    """
    # Raised explicitly (rather than via ``assert``) so the checks survive
    # ``python -O`` while keeping the AssertionError type of the original.
    if not 0 <= train_size + val_size <= 1:
        raise AssertionError(
            "The sum of valid training set size and validation set size "
            "must between 0 and 1 (inclusive)."
        )
    # The sum check alone lets a negative fraction through (e.g. -0.1 + 0.5);
    # a negative slice bound would then make the train/val masks overlap.
    if not (0 <= train_size <= 1 and 0 <= val_size <= 1):
        raise AssertionError(
            "train_size and val_size must each be between 0 and 1 "
            "(inclusive)."
        )

    N = x.shape[0]
    if self.name == "amazon":
        # 0-3304 are unlabeled nodes
        index = np.arange(3305, N)
    else:
        index = np.arange(N)

    index = np.random.RandomState(seed).permutation(index)

    # Train takes the head, validation takes the tail, test is the middle.
    train_end = int(train_size * len(index))
    val_start = len(index) - int(val_size * len(index))
    splits = {
        "train_mask": index[:train_end],
        "val_mask": index[val_start:],
        "test_mask": index[train_end:val_start],
    }

    # Masks always span all N entries, including any excluded from ``index``
    # (those stay False in every mask).
    for key, idx in splits.items():
        mask = np.zeros(N, dtype=np.bool_)
        mask[idx] = True
        self.graph.ndata[key] = F.tensor(mask)
239
240
241class FraudYelpDataset(FraudDataset):

Callers 1

processMethod · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected