(self, rank, model_w, dataset_w)
| 303 | return model_w.to(rank), None |
| 304 | |
| 305 | def train(self, rank, model_w, dataset_w): # noqa: C901 |
| 306 | model_w, _ = self.initialize(model_w, rank=rank, master_addr=self.master_addr, master_port=self.master_port) |
| 307 | self.data_controller.prepare_data_wrapper(dataset_w, rank) |
| 308 | self.eval_data_back_to_cpu = dataset_w.data_back_to_cpu |
| 309 | |
| 310 | optimizers, lr_schedulers = self.build_optimizer(model_w) |
| 311 | if optimizers[0] is None: |
| 312 | return |
| 313 | |
| 314 | est = model_w.set_early_stopping() |
| 315 | if isinstance(est, str): |
| 316 | est_monitor = est |
| 317 | best_index, compare_fn = evaluation_comp(est_monitor) |
| 318 | else: |
| 319 | assert len(est) == 2 |
| 320 | est_monitor, est_compare = est |
| 321 | best_index, compare_fn = evaluation_comp(est_monitor, est_compare) |
| 322 | self.monitor = est_monitor |
| 323 | self.evaluation_metric = model_w.evaluation_metric |
| 324 | |
| 325 | best_model_w = None |
| 326 | |
| 327 | scaler = GradScaler() if self.fp16 else None |
| 328 | |
| 329 | patience = 0 |
| 330 | best_epoch = 0 |
| 331 | for stage in range(self.nstage): |
| 332 | with torch.no_grad(): |
| 333 | pre_stage_out = model_w.pre_stage(stage, dataset_w) |
| 334 | dataset_w.pre_stage(stage, pre_stage_out) |
| 335 | self.data_controller.training_proc_per_stage(dataset_w, rank) |
| 336 | |
| 337 | if self.progress_bar == "epoch": |
| 338 | epoch_iter = tqdm(range(1, self.epochs + 1)) |
| 339 | epoch_printer = Printer(epoch_iter.set_description, rank=rank, world_size=self.world_size) |
| 340 | else: |
| 341 | epoch_iter = range(1, self.epochs + 1) |
| 342 | epoch_printer = Printer(print, rank=rank, world_size=self.world_size) |
| 343 | |
| 344 | self.logger.start() |
| 345 | print_str_dict = dict() |
| 346 | if self.attack is not None: |
| 347 | graph = dataset_w.dataset.data |
| 348 | graph_backup = copy.deepcopy(graph) |
| 349 | graph0 = copy.deepcopy(graph) |
| 350 | num_train = torch.sum(graph.train_mask).item() |
| 351 | for epoch in epoch_iter: |
| 352 | for hook in self.pre_epoch_hooks: |
| 353 | hook(self) |
| 354 | |
| 355 | # inductive setting .. |
| 356 | dataset_w.train() |
| 357 | train_loader = dataset_w.on_train_wrapper() |
| 358 | train_dataset = train_loader.get_dataset_from_loader() |
| 359 | if hasattr(train_dataset, "shuffle"): |
| 360 | train_dataset.shuffle() |
| 361 | training_loss = self.train_step(model_w, train_loader, optimizers, lr_schedulers, rank, scaler) |
| 362 |
no test coverage detected