(self, g)
| 1571 | self.dist = Bernoulli(p) |
| 1572 | |
def __call__(self, g):
    """Return a copy of ``g`` with nodes randomly dropped.

    The input graph is never modified: a clone is made first, and that clone
    is what gets returned (even when nothing is dropped).

    For every node type, one sample per node is drawn from ``self.dist``
    (a ``Bernoulli(p)`` distribution built from ``p``). A node whose sample
    is 1 is removed from the clone via ``remove_nodes``.

    Parameters
    ----------
    g
        The graph to augment. It must provide ``clone``, ``ntypes``,
        ``num_nodes``, ``nodes``, ``device`` and ``remove_nodes``
        (presumably a ``DGLGraph`` — confirm against callers).

    Returns
    -------
    A new graph object holding the surviving nodes.
    """
    graph = g.clone()

    # Nothing can be dropped when the drop probability is zero, so skip the
    # sampling entirely and hand back the (unmodified) copy.
    if self.p == 0:
        return graph

    for ntype in graph.ntypes:
        num_of_type = graph.num_nodes(ntype)
        # One Bernoulli draw per node of this type; the samples are produced
        # on the distribution's device and must be moved to the graph's
        # device before they can index the node-ID tensor.
        drop_mask = self.dist.sample(torch.Size([num_of_type])).bool().to(graph.device)
        graph.remove_nodes(graph.nodes(ntype)[drop_mask], ntype=ntype)

    return graph
| 1585 | |
| 1586 | |
| 1587 | # pylint: disable=C0103 |