| 181 | return [i for i in device_ids], len(device_ids) |
| 182 | |
def run(self, model_w: ModelWrapper, dataset_w: DataWrapper):
    """Train ``model_w`` on ``dataset_w`` and return the outcome.

    Embedding model wrappers are delegated entirely to ``EmbeddingTrainer``.
    Any other wrapper receives the dataset wrapper's default loss function
    and evaluator, is optionally restored from ``self.checkpoint_path``
    (when ``self.resume_training`` is set), is trained either through
    ``self.dist_train`` or ``self.train`` on ``self.devices[0]``, and is
    then reloaded from ``self.checkpoint_path``.

    Args:
        model_w: The model wrapper to train.
        dataset_w: The data wrapper supplying the dataset and defaults.

    Returns:
        The value of ``EmbeddingTrainer.run`` for embedding wrappers;
        the reloaded wrapper's ``model`` attribute when
        ``self.return_model`` is truthy; otherwise whatever
        ``self.evaluate`` returns for the reloaded wrapper.
    """
    # Network/graph embedding models take a dedicated trainer.
    if isinstance(model_w, EmbeddingModelWrapper):
        embedding_trainer = EmbeddingTrainer(self.save_emb_path, self.load_emb_path)
        return embedding_trainer.run(model_w, dataset_w)

    param_total = sum(p.numel() for p in model_w.parameters())
    print("Model Parameters:", param_total)

    # Deep learning models: install the dataset's default loss function and
    # evaluator on the wrapper (mainly for the in-cogdl setting).
    model_w.default_loss_fn = dataset_w.get_default_loss_fn()
    model_w.default_evaluator = dataset_w.get_default_evaluator()
    model_w.set_evaluation_metric()

    if self.resume_training:
        restored = load_model(model_w, self.checkpoint_path)
        model_w = restored.to(self.devices[0])

    if not self.distributed_training:
        self.train(self.devices[0], model_w, dataset_w)
    else:
        torch.multiprocessing.set_sharing_strategy("file_system")
        self.dist_train(model_w, dataset_w)

    # Reload from the checkpoint path after training
    # (presumably the best checkpoint written during training -- confirm).
    reloaded = load_model(model_w, self.checkpoint_path)
    best_model_w = reloaded.to(self.devices[0])

    if not self.return_model:
        test_result = self.evaluate(best_model_w, dataset_w)

        # Move graph data back to the CPU to clear the GPU memory.
        underlying = dataset_w.get_dataset()
        if isinstance(underlying.data, Graph) or hasattr(underlying.data, "graphs"):
            underlying.data.to("cpu")

        return test_result

    return best_model_w.model
| 218 | |
| 219 | def evaluate(self, model_w: ModelWrapper, dataset_w: DataWrapper, cpu=False): |
| 220 | if cpu: |