(self, data, device)
| 29 | self.method = "BGRL" |
| 30 | |
def _feature_masking(self, data, device):
    """Build two stochastically augmented views of one graph (BGRL-style).

    Each view gets its own feature-column mask and its own edge dropout.
    A feature column is kept when a uniform [0, 1) draw is strictly greater
    than ``self.p_f1`` (view 1) or ``self.p_f2`` (view 2). Edges are dropped
    by ``dropout_adj`` with rates ``self.p_e1`` / ``self.p_e2``.

    Args:
        data: graph object exposing ``x``, ``edge_index``, ``edge_attr`` and
            ``clone()`` (presumably a torch_geometric ``Data`` -- confirm).
        device: device the feature masks are moved to before being applied.

    Returns:
        tuple: ``(new_data1, new_data2)``, clones of ``data`` whose ``x``,
        ``edge_index`` and ``edge_attr`` are replaced by the augmented values.
    """
    num_features = data.x.shape[1]

    # Masks are drawn on the CPU (as the legacy ``torch.FloatTensor(n).uniform_()``
    # did) and only then moved to ``device``; ``torch.rand`` samples U[0, 1),
    # the same range as ``uniform_()``'s defaults.
    feat_mask1 = (torch.rand(num_features) > self.p_f1).to(device)
    feat_mask2 = (torch.rand(num_features) > self.p_f2).to(device)

    # Assumes x is 2-D (nodes x features): the length-F boolean mask broadcasts
    # over rows and zeroes whole feature columns. The product is already a new
    # tensor, so cloning ``data.x`` first would only copy it twice.
    x1 = data.x * feat_mask1
    x2 = data.x * feat_mask2

    # NOTE(review): torch_geometric.utils.dropout_adj names its rate ``p``;
    # ``drop_rate`` only works if this ``dropout_adj`` is a local wrapper -- confirm.
    edge_index1, edge_attr1 = dropout_adj(data.edge_index, data.edge_attr, drop_rate=self.p_e1)
    edge_index2, edge_attr2 = dropout_adj(data.edge_index, data.edge_attr, drop_rate=self.p_e2)

    # Full clones keep every other attribute of ``data`` independent of the views.
    new_data1, new_data2 = data.clone(), data.clone()
    new_data1.x, new_data2.x = x1, x2
    new_data1.edge_index, new_data2.edge_index = edge_index1, edge_index2
    new_data1.edge_attr, new_data2.edge_attr = edge_attr1, edge_attr2

    return new_data1, new_data2
| 47 | |
| 48 | def __call__(self, data): |
| 49 |
no test coverage detected