(self, g, x, mask_rate=0.3)
| 195 | return criterion |
| 196 | |
def encoding_mask_noise(self, g, x, mask_rate=0.3):
    """Randomly mask node features and return the corrupted copy.

    A random subset of ``int(mask_rate * g.num_nodes)`` nodes is selected.
    When ``self._replace_rate > 0`` the selected nodes are split in two:
    "token" nodes have their features zeroed and ``self.enc_mask_token``
    added, and "noise" nodes have their features overwritten with the
    features of randomly drawn nodes of ``x``. Because the two counts are
    rounded down independently, a few selected nodes may be left
    unchanged. When ``self._replace_rate`` is not positive, every selected
    node is treated as a token node.

    Args:
        g: Graph object exposing ``num_nodes`` and ``clone()``.
        x: Node feature tensor indexed by node along its first dimension
            (assumed to have ``g.num_nodes`` rows -- confirm with caller).
        mask_rate: Fraction of nodes to select for masking.

    Returns:
        A tuple ``(use_g, out_x, (mask_nodes, keep_nodes))`` where
        ``use_g`` is ``g.clone()``, ``out_x`` is the corrupted copy of
        ``x`` (``x`` itself is not modified), and ``mask_nodes`` /
        ``keep_nodes`` are index tensors of the selected and remaining
        nodes.
    """
    num_nodes = g.num_nodes
    perm = torch.randperm(num_nodes, device=x.device)

    # random masking
    num_mask_nodes = int(mask_rate * num_nodes)
    mask_nodes = perm[:num_mask_nodes]
    keep_nodes = perm[num_mask_nodes:]

    out_x = x.clone()
    if self._replace_rate > 0:
        num_noise_nodes = int(self._replace_rate * num_mask_nodes)
        num_token_nodes = int(self._mask_token_rate * num_mask_nodes)
        perm_mask = torch.randperm(num_mask_nodes, device=x.device)
        token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
        # Slice from an explicit start index rather than ``-num_noise_nodes``:
        # when the count rounds down to 0, ``[-0:]`` would select every
        # masked node and the assignment below would fail with a shape
        # mismatch against the empty ``noise_to_be_chosen``.
        noise_nodes = mask_nodes[perm_mask[num_mask_nodes - num_noise_nodes:]]
        noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[:num_noise_nodes]

        out_x[token_nodes] = 0.0
        out_x[noise_nodes] = x[noise_to_be_chosen]
    else:
        token_nodes = mask_nodes
        out_x[mask_nodes] = 0.0

    # Token nodes were zeroed above, so after this they hold the mask token.
    out_x[token_nodes] += self.enc_mask_token
    use_g = g.clone()

    return use_g, out_x, (mask_nodes, keep_nodes)
| 226 | |
| 227 | def forward(self, g, x): |
| 228 | # ---- attribute reconstruction ---- |
no test coverage detected