Repository: github.com/dmlc/dgl

Function add_nodepred_split

python/dgl/data/utils.py:445–495

Split the given dataset into training, validation and test sets for a transductive node prediction task. It adds three node mask arrays, 'train_mask', 'val_mask' and 'test_mask', to each graph in the dataset, so each sample in the dataset must be a DGLGraph.
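The core of the split is a single random permutation of node IDs, sliced into three consecutive ranges that become the three masks. The NumPy-only sketch below reproduces that idea without DGL. `split_masks` is a hypothetical helper, it returns boolean NumPy arrays rather than the framework tensors the real function stores in `g.nodes[ntype].data`, and it uses a `numpy.random.Generator` in place of the global-state `np.random.shuffle` that the source calls.

```python
import numpy as np


def split_masks(n, ratio, rng=None):
    """Hypothetical helper: return (train, val, test) boolean masks of length n.

    Mirrors the slicing in add_nodepred_split: train and val sizes are
    truncated with int(), and the test set takes every remaining node.
    """
    if len(ratio) != 3:
        raise ValueError(f"Split ratio must be a float triplet but got {ratio}.")
    rng = np.random.default_rng() if rng is None else rng
    idx = rng.permutation(n)  # shuffled node IDs 0..n-1
    n_train = int(n * ratio[0])
    n_val = int(n * ratio[1])

    def to_mask(ids):
        # Boolean mask with True at the selected node IDs.
        mask = np.zeros(n, dtype=bool)
        mask[ids] = True
        return mask

    train = to_mask(idx[:n_train])
    val = to_mask(idx[n_train:n_train + n_val])
    test = to_mask(idx[n_train + n_val:])
    return train, val, test


train, val, test = split_masks(10, (0.8, 0.1, 0.1), rng=np.random.default_rng(0))
print(train.sum(), val.sum(), test.sum())  # 8 1 1
```

Because all three masks are slices of one permutation, every node lands in exactly one of them.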

(dataset, ratio, ntype=None)
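The docstring states that `ratio` must sum to one, but the code only checks that it has three entries. A caller who wants the stronger contract could validate before calling. `validate_split_ratio` is a hypothetical helper, not part of DGL, and the tolerance is an assumption chosen to absorb floating-point rounding in the sum.

```python
import math


def validate_split_ratio(ratio, tol=1e-6):
    """Hypothetical pre-check: three non-negative ratios summing to about one."""
    if len(ratio) != 3:
        raise ValueError(f"Split ratio must be a float triplet but got {ratio}.")
    if any(r < 0 for r in ratio):
        raise ValueError(f"Split ratios must be non-negative but got {ratio}.")
    total = sum(ratio)
    # Floating-point sums are rarely exact, so compare with a tolerance.
    if not math.isclose(total, 1.0, abs_tol=tol):
        raise ValueError(f"Split ratios must sum to one but sum to {total}.")
    return tuple(float(r) for r in ratio)


validate_split_ratio([0.8, 0.1, 0.1])    # passes
# validate_split_ratio([0.5, 0.1, 0.1])  # would raise ValueError
```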

# Excerpt from python/dgl/data/utils.py: np (NumPy), idx2mask and
# generate_mask_tensor come from that module's scope.
def add_nodepred_split(dataset, ratio, ntype=None):
    """Split the given dataset into training, validation and test sets for
    a transductive node prediction task.

    It adds three node mask arrays, ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``,
    to each graph in the dataset. Each sample in the dataset thus must be a :class:`DGLGraph`.

    Fix the random seed of NumPy to make the result deterministic::

        numpy.random.seed(42)

    Parameters
    ----------
    dataset : DGLDataset
        The dataset to modify.
    ratio : (float, float, float)
        Split ratios for training, validation and test sets. Must sum to one.
    ntype : str, optional
        The node type to add masks for.

    Examples
    --------
    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()
    >>> print('train_mask' in dataset[0].ndata)
    False
    >>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
    >>> print('train_mask' in dataset[0].ndata)
    True
    """
    if len(ratio) != 3:
        raise ValueError(
            f"Split ratio must be a float triplet but got {ratio}."
        )
    for i in range(len(dataset)):
        g = dataset[i]
        n = g.num_nodes(ntype)
        # Shuffle node IDs with NumPy's global random state.
        idx = np.arange(0, n)
        np.random.shuffle(idx)
        n_train, n_val, n_test = (
            int(n * ratio[0]),
            int(n * ratio[1]),
            int(n * ratio[2]),
        )
        train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n))
        val_mask = generate_mask_tensor(
            idx2mask(idx[n_train : n_train + n_val], n)
        )
        # The test set takes every remaining node, so n_test is not used.
        test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val :], n))
        g.nodes[ntype].data["train_mask"] = train_mask
        g.nodes[ntype].data["val_mask"] = val_mask
        g.nodes[ntype].data["test_mask"] = test_mask
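Two details of the loop above are easy to miss. The train and validation sizes are truncated with `int()`, and the test mask is built from `idx[n_train + n_val :]` rather than from `n_test`, so the test set absorbs any rounding remainder and `ratio[2]` does not affect the result. Also, because the shuffle uses NumPy's global random state, seeding it first makes the split reproducible, as the docstring suggests. The sketch below checks both with plain NumPy; it copies the slicing arithmetic instead of calling DGL, and `split_sizes` and `shuffled_ids` are hypothetical names.

```python
import numpy as np


def split_sizes(n, ratio):
    """Hypothetical helper: mask sizes produced by the slicing in the loop above.

    Assumes n_train + n_val <= n. Like the source, it ignores ratio[2].
    """
    n_train = int(n * ratio[0])   # truncated, as in the source
    n_val = int(n * ratio[1])     # truncated, as in the source
    n_test = n - n_train - n_val  # what idx[n_train + n_val:] holds
    return n_train, n_val, n_test


def shuffled_ids(n, seed):
    """Seed the global NumPy state, then shuffle node IDs as the source does."""
    np.random.seed(seed)
    idx = np.arange(0, n)
    np.random.shuffle(idx)
    return idx


# 7 * 0.6 = 4.2 -> 4 and 7 * 0.2 = 1.4 -> 1, so the test set gets 7 - 5 = 2 nodes.
print(split_sizes(7, (0.6, 0.2, 0.2)))  # (4, 1, 2)
# A third entry of 0.0 gives the same sizes, since ratio[2] is never used.
print(split_sizes(7, (0.6, 0.2, 0.0)))  # (4, 1, 2)
# Re-seeding before shuffling reproduces the same permutation.
print((shuffled_ids(7, 42) == shuffled_ids(7, 42)).all())  # True
```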

Callers

nothing calls this directly

Calls (4)

generate_mask_tensor · Function · 0.85
idx2mask · Function · 0.85
num_nodes · Method · 0.45
shuffle · Method · 0.45
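The two helper functions in the call list turn each shuffled index slice into a stored mask. Their bodies are not shown on this page, so the sketch below rests on assumptions: that `idx2mask(idx, n)` returns a length-`n` NumPy array with ones at the listed positions, and that `generate_mask_tensor` converts such an array into a boolean tensor for the active backend. The `_sketch` functions are hypothetical stand-ins, and the conversion here yields a NumPy bool array instead of a framework tensor.

```python
import numpy as np


def idx2mask_sketch(idx, n):
    """Assumed behavior of idx2mask: 0/1 float array with ones at `idx`."""
    mask = np.zeros(n)
    mask[idx] = 1
    return mask


def generate_mask_tensor_sketch(mask):
    """Stand-in for generate_mask_tensor: returns a NumPy bool array.

    The real helper is assumed to build a backend tensor instead.
    """
    assert isinstance(mask, np.ndarray)
    return mask.astype(bool)


idx = np.array([3, 0, 4, 1, 2])  # a pretend shuffled permutation of 5 node IDs
train_mask = generate_mask_tensor_sketch(idx2mask_sketch(idx[:3], 5))
print(train_mask.tolist())  # [True, False, False, True, True]
```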

Tested by

no test coverage detected