(self, g)
| 1620 | self.dist = Bernoulli(p) |
| 1621 | |
def __call__(self, g):
    """Return a copy of ``g`` with randomly sampled edges removed.

    The graph is cloned first and all removals are applied to the clone.
    The caller's graph is therefore not edited in place, assuming
    ``clone`` yields an independent copy. A clone is returned even when
    ``self.p == 0``.

    For every canonical edge type, one sample per edge is drawn from
    ``self.dist``. Edges whose sample is nonzero are removed.
    ``self.dist`` is built as ``Bernoulli(p)`` in ``__init__``, so each
    edge is presumably dropped with probability ``self.p``; confirm that
    ``self.p`` stores the same ``p``.

    Parameters
    ----------
    g : graph
        Input graph, presumably a ``DGLGraph`` (TODO confirm). It must
        provide ``clone``, ``canonical_etypes``, ``num_edges``, ``edges``,
        ``remove_edges`` and ``device``.

    Returns
    -------
    graph
        The cloned graph with the sampled edges removed.
    """
    out = g.clone()

    # With a zero drop probability nothing can be removed; skip sampling.
    if self.p == 0:
        return out

    for etype in out.canonical_etypes:
        n_edges = out.num_edges(etype)
        # One draw per edge; a nonzero draw marks that edge for removal.
        drop_mask = self.dist.sample(torch.Size([n_edges])).bool()
        all_eids = out.edges(form="eid", etype=etype)
        # Samples come from self.dist; move the mask onto the graph's
        # device so it can index the edge-ID tensor.
        out.remove_edges(all_eids[drop_mask.to(out.device)], etype=etype)
    return out
| 1636 | |
| 1637 | |
| 1638 | class AddEdge(BaseTransform): |