MCPcopy
hub / github.com/THUDM/CogDL / preprocessing

Method preprocessing

cogdl/models/nn/sign.py:136–157  ·  view source on GitHub ↗
(self, graph, x)

Source from the content-addressed store, hash-verified

134 return torch.cat(op_embedding, dim=1).to(device)
135
def preprocessing(self, graph, x):
    """Compute (or load from an on-disk cache) the propagated node features.

    When ``self.dataset_name`` is set, the result is cached in a file whose
    name is built from the dataset name, ``self.num_propagations``,
    ``self.diffusion`` and the comma-joined ``self.adj_norm``. If that file
    already exists it is loaded, moved to ``x.device`` and returned without
    recomputing anything.

    For inductive graphs the features are computed twice:

    * once with the graph in train mode (``drop_edge=True``);
    * once in eval mode (``drop_edge=False``).

    Rows belonging to the training nodes (``graph.train_nid``) are taken from
    the train-mode result; all other rows come from the eval-mode result. The
    graph is always left in eval mode afterwards, even if propagation raises.

    Args:
        graph: input graph. It must provide ``is_inductive()``, ``train()``,
            ``eval()`` and, for inductive graphs, ``train_nid``.
        x: node-feature tensor forwarded to ``self._preprocessing``.

    Returns:
        The tensor produced by ``self._preprocessing`` (or loaded from cache).
        A newly computed result is returned on the device ``_preprocessing``
        produced it on. The cached copy is always stored on CPU.
    """
    print("Preprocessing...")
    dataset_name = None
    if self.dataset_name is not None:
        adj_norm = ",".join(self.adj_norm)
        dataset_name = f"{self.dataset_name}_{self.num_propagations}_{self.diffusion}_{adj_norm}.pt"
        if os.path.exists(dataset_name):
            # NOTE: torch.load unpickles the file; only point dataset_name at
            # trusted cache locations.
            return torch.load(dataset_name).to(x.device)

    if graph.is_inductive():
        # NOTE(review): presumably train() switches the graph to its
        # training-time view for inductive datasets -- confirm in Graph.
        graph.train()
        try:
            x_train = self._preprocessing(graph, x, drop_edge=True)
        finally:
            # Restore eval mode even if propagation fails, matching the state
            # the success path leaves the graph in.
            graph.eval()
        x_all = self._preprocessing(graph, x, drop_edge=False)
        train_nid = graph.train_nid
        # Training nodes keep the features computed in train mode.
        x_all[train_nid] = x_train[train_nid]
    else:
        x_all = self._preprocessing(graph, x, drop_edge=False)

    if dataset_name is not None:
        # Save to a per-process temporary file, then rename it atomically.
        # An interrupted save therefore never leaves a truncated cache that
        # the os.path.exists() check above would keep loading.
        tmp_name = f"{dataset_name}.{os.getpid()}.tmp"
        torch.save(x_all.cpu(), tmp_name)
        os.replace(tmp_name, dataset_name)
    print("Preprocessing Done...")
    return x_all
158
159 def forward(self, graph):
160 if self.cache_x is None:

Callers 1

forward (method) · 0.95

Calls 5

_preprocessing (method) · 0.95
is_inductive (method) · 0.80
to (method) · 0.45
train (method) · 0.45
eval (method) · 0.45

Tested by

no test coverage detected