(self)
| 88 | ) |
| 89 | |
def process(self):
    """Prepare ``self.g`` with labels, split masks and ``num_classes``.

    Steps, in order:

    1. Clone the first graph of ``self.dataset`` into ``self.g``. When the
       dataset exposes ``get_idx_split`` (treated here as the OGB case),
       ``self.dataset[0]`` is unpacked as ``(graph, label)`` and the label
       is reshaped to one entry per node and stored under ``"label"``.
    2. Require a ``"label"`` field on the ``self.target_ntype`` nodes.
    3. Obtain train/val/test masks: generated from ``self.split_ratio``
       when it is given, converted from the dataset's index split in the
       OGB case, or otherwise required to already exist on the graph.
    4. Call ``self._set_split_index()``.
    5. Set ``self.num_classes`` from the dataset, falling back to the
       number of unique label values.

    Raises
    ------
    ValueError
        If no ``"label"`` field is present on the target node type.
    AssertionError
        If ``split_ratio`` is None, the dataset has no ``get_idx_split``,
        and one of ``train_mask`` / ``val_mask`` / ``test_mask`` is missing.
    """
    is_ogb = hasattr(self.dataset, "get_idx_split")
    if not is_ogb:
        self.g = self.dataset[0].clone()
    else:
        # Datasets with get_idx_split yield a (graph, label) pair.
        graph, labels = self.dataset[0]
        self.g = graph.clone()
        self.g.ndata["label"] = F.reshape(labels, (graph.num_nodes(),))

    if "label" not in self.g.nodes[self.target_ntype].data:
        raise ValueError(
            "Missing node labels. Make sure labels are stored "
            "under name 'label'."
        )

    if self.split_ratio is not None:
        # An explicit ratio was given: generate the masks from it.
        if self.verbose:
            print("Generating train/val/test masks...")
        utils.add_nodepred_split(self, self.split_ratio, self.target_ntype)
    elif is_ogb:
        # Convert the dataset-provided index split into node masks.
        split = self.dataset.get_idx_split()
        split_indices = [split[key] for key in ("train", "valid", "test")]
        node_count = self.g.num_nodes()
        # Build every mask before storing any of them on the graph.
        masks = [
            utils.generate_mask_tensor(utils.idx2mask(idx, node_count))
            for idx in split_indices
        ]
        for field, mask in zip(("train_mask", "val_mask", "test_mask"), masks):
            self.g.ndata[field] = mask
    else:
        # No ratio and no index split: the masks must already be present.
        required_masks = (
            (
                "train_mask",
                "train_mask is not provided, please specify split_ratio to generate the masks",
            ),
            (
                "val_mask",
                "val_mask is not provided, please specify split_ratio to generate the masks",
            ),
            (
                "test_mask",
                "test_mask is not provided, please specify split_ratio to generate the masks",
            ),
        )
        for field, message in required_masks:
            assert field in self.g.nodes[self.target_ntype].data, message

    self._set_split_index()

    self.num_classes = getattr(self.dataset, "num_classes", None)
    if self.num_classes is None:
        # Dataset does not report a class count; count distinct labels.
        target_labels = self.g.nodes[self.target_ntype].data["label"]
        self.num_classes = len(F.unique(target_labels))
nothing calls this directly
no test coverage detected