MCPcopy
hub / github.com/THUDM/CogDL / _feature_masking

Method _feature_masking

examples/bgrl/utils.py:31–46  ·  view source on GitHub ↗
(self, data, device)

Source from the content-addressed store, hash-verified

29 self.method = "BGRL"
30
def _feature_masking(self, data, device):
    """Build two independently perturbed copies of ``data``.

    Each copy gets its own random feature mask and its own call to
    ``dropout_adj``:

    * A boolean mask of length ``data.x.shape[1]`` is drawn from a uniform
      sample; an entry is kept when the sample exceeds ``self.p_f1``
      (first copy) or ``self.p_f2`` (second copy). The mask is multiplied
      into a clone of ``data.x``.
      NOTE(review): presumably ``data.x`` is 2-D (nodes x features), so the
      mask zeroes whole feature columns -- confirm against callers.
    * ``dropout_adj`` is called on ``data.edge_index`` / ``data.edge_attr``
      with ``drop_rate=self.p_e1`` (first copy) or ``self.p_e2`` (second).

    Args:
        data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and
            ``clone()``.
        device: Device the feature masks are moved to before being applied.

    Returns:
        A pair ``(new_data1, new_data2)``: two clones of ``data`` whose
        ``x``, ``edge_index`` and ``edge_attr`` have been replaced by the
        perturbed versions.
    """
    num_features = data.x.shape[1]

    # Draw both masks before touching the edges so the random calls happen
    # in the order: mask 1, mask 2, edges 1, edges 2.
    keep_first = torch.FloatTensor(num_features).uniform_() > self.p_f1
    keep_second = torch.FloatTensor(num_features).uniform_() > self.p_f2
    keep_first = keep_first.to(device)
    keep_second = keep_second.to(device)

    masked_first = data.x.clone() * keep_first
    masked_second = data.x.clone() * keep_second

    edges_first = dropout_adj(data.edge_index, data.edge_attr, drop_rate=self.p_e1)
    edges_second = dropout_adj(data.edge_index, data.edge_attr, drop_rate=self.p_e2)

    view_first = data.clone()
    view_second = data.clone()
    perturbed = (
        (view_first, masked_first, edges_first),
        (view_second, masked_second, edges_second),
    )
    for view, features, (edge_index, edge_attr) in perturbed:
        # Same per-object assignment order as before: x, edge_index, edge_attr.
        view.x = features
        view.edge_index = edge_index
        view.edge_attr = edge_attr

    return view_first, view_second
47
48 def __call__(self, data):
49

Callers 2

trainMethod · 0.95
__call__Method · 0.95

Calls 3

dropout_adjFunction · 0.90
toMethod · 0.45
cloneMethod · 0.45

Tested by

no test coverage detected