(
edge_index: Tuple,
edge_weight: Optional[torch.Tensor] = None,
drop_rate: float = 0.5,
renorm: Optional[str] = "sym",
training: bool = False,
)
| 38 | |
| 39 | |
def dropout_adj(
    edge_index: Tuple,
    edge_weight: Optional[torch.Tensor] = None,
    drop_rate: float = 0.5,
    renorm: Optional[str] = "sym",
    training: bool = False,
):
    """Randomly drop edges from a graph while always keeping self-loops.

    Args:
        edge_index: Pair ``(row, col)`` of index tensors, one entry per edge
            (presumably 1-D integer node-index tensors; confirm against callers).
        edge_weight: Optional per-edge weights aligned with ``edge_index``.
            On the no-drop path, ``None`` is replaced by a float tensor of ones.
        drop_rate: Probability of dropping each non-self-loop edge. Must lie
            in ``[0, 1]``.
        renorm: ``"sym"`` recomputes weights with ``symmetric_normalization``;
            ``"row"`` recomputes them with ``row_normalization``. Any other
            value (including ``None``) keeps the weights returned by
            ``filter_adj``.
        training: Edges are dropped only when ``True``.

    Returns:
        Tuple ``(edge_index, edge_weight)``. When nothing is dropped (eval mode,
        ``drop_rate == 0``, or no edges), the inputs are returned unchanged,
        apart from filling in a missing ``edge_weight``.

    Raises:
        ValueError: If ``drop_rate`` is outside ``[0, 1]`` or is NaN.
    """
    # Validate before any early return (matching torch.nn.functional.dropout),
    # so a bad rate is reported even in eval mode. The chained comparison
    # also rejects NaN, which ``x < 0 or x > 1`` would let through.
    if not 0.0 <= drop_rate <= 1.0:
        raise ValueError(f"Dropout probability has to be between 0 and 1, but got {drop_rate}")

    row, col = edge_index

    # Nothing to drop: eval mode, zero rate, or an empty edge set.
    # row.max() below would raise on an empty tensor.
    if not training or drop_rate == 0 or row.numel() == 0:
        if edge_weight is None:
            edge_weight = torch.ones(row.shape[0], device=row.device)
        return edge_index, edge_weight

    # Count nodes before dropping, so nodes that lose all their edges are
    # still included in the normalization below.
    num_nodes = int(max(row.max(), col.max())) + 1
    self_loop = row == col

    # Keep each edge independently with probability 1 - drop_rate.
    mask = torch.full((row.shape[0],), 1 - drop_rate, dtype=torch.float, device=row.device)
    mask = torch.bernoulli(mask).to(torch.bool)
    # Self-loops are never dropped.
    mask = self_loop | mask
    edge_index, edge_weight = filter_adj(row, col, edge_weight, mask)

    # Renormalization recomputes weights from the surviving structure only.
    # The filtered input weights are discarded on these two paths.
    if renorm == "sym":
        edge_weight = symmetric_normalization(num_nodes, edge_index[0], edge_index[1])
    elif renorm == "row":
        edge_weight = row_normalization(num_nodes, edge_index[0], edge_index[1])
    return edge_index, edge_weight
| 67 | |
| 68 | |
| 69 | def dropout_features(x: torch.Tensor, droprate: float, training=True): |
no test coverage detected