| 319 | self.dist = Bernoulli(p) |
| 320 | |
def __call__(self, g):
    """Randomly zero out feature columns of the configured node/edge features.

    For every name in ``self.node_feat_names`` (and likewise
    ``self.edge_feat_names``), a 0/1 mask with one entry per slot of the
    feature's LAST dimension is sampled from ``self.dist``. Every slot whose
    mask entry is 1 is set to 0 for all nodes (edges). When the stored feature
    is not a single tensor but a mapping keyed by node/edge type, an
    independent mask is drawn for each type's tensor.

    The feature tensors are modified in place and the same graph object is
    returned. The transform is intended for features of rank >= 2; for a 1-D
    tensor the last dimension is the item axis itself, so individual entries
    would be masked.

    Parameters
    ----------
    g : graph
        Graph object exposing ``ndata`` and ``edata`` (presumably a DGLGraph —
        confirm against callers).

    Returns
    -------
    graph
        ``g`` itself, with the selected feature columns zeroed.
    """
    # Fast path: with probability 0 nothing would ever be masked.
    if self.p == 0:
        return g

    def mask_columns(feat):
        # Draw one mask entry per slot of the last dimension and index that
        # same dimension with ``...`` so rank > 2 features such as
        # (N, heads, dim) are masked along the axis the mask was sized for.
        # The previous ``[:, mask]`` indexed dim 1 instead and failed
        # (or hit the wrong axis) whenever rank > 2.
        feat_mask = self.dist.sample(torch.Size([feat.shape[-1]]))
        feat[..., feat_mask.bool().to(feat.device)] = 0

    def mask_named(store, names):
        # ``store`` is g.ndata or g.edata; ``names`` the feature keys to mask.
        for name in names:
            feats = store[name]
            if isinstance(feats, torch.Tensor):
                mask_columns(feats)
            else:
                # Per-type mapping: independent mask for every type's tensor.
                for typed_feat in feats.values():
                    mask_columns(typed_feat)

    mask_named(g.ndata, self.node_feat_names)
    mask_named(g.edata, self.edge_feat_names)
    return g
| 376 | |
| 377 | |
| 378 | class RandomWalkPE(BaseTransform): |