(self)
| 61 | self._scheduler = optim.lr_scheduler.LambdaLR(self._optimizer, lr_lambda=lr_scheduler) |
| 62 | |
def train(self):
    """Run the self-supervised training loop with periodic evaluation.

    Evaluates the embeddings once before training. Then, for each of
    ``self._args.epochs`` epochs, it:

    - builds two feature-masked views of ``self._dataset``,
    - feeds them to ``self._model`` (which returns the loss as its third
      output),
    - takes one optimizer step and one LR-scheduler step,
    - calls ``self._model.update_moving_average()``.

    Every ``self._args.cache_step`` epochs the embeddings are re-inferred and
    evaluated. The current learning rate and test accuracy are then logged
    to ``self.writer``.
    """
    # get initial test results
    print(self._args)
    print("start training!")

    print("Initial Evaluation...")
    self.infer_embeddings()
    initial_acc, _ = self.evaluate()
    print("test: {:.4f}".format(initial_acc))

    # The augmentation depends only on the fixed CLI parameters, so build it
    # once instead of re-creating an identical object every epoch.
    # NOTE(review): assumes utils.Augmentation keeps no state that was meant
    # to be reset each epoch — confirm.
    augmentation = utils.Augmentation(
        *(float(self._args.aug_params[i]) for i in range(4)))

    # start training
    for epoch in range(self._args.epochs):
        # Re-enter training mode every epoch rather than once before the loop.
        # If infer_embeddings()/evaluate() switch the model to eval mode (their
        # bodies are not visible here — TODO confirm), the previous code would
        # have trained every epoch after the first evaluation with
        # dropout/batch-norm in inference mode. train() is a no-op otherwise.
        self._model.train()

        # Kept inside the loop in case evaluation moves the dataset elsewhere.
        self._dataset.to(self._device)

        view1, view2 = augmentation._feature_masking(self._dataset, self._device)

        # Only the loss is needed here; the per-view outputs are discarded.
        _, _, loss = self._model(
            x1=view1.x, x2=view2.x, graph_v1=view1, graph_v2=view2,
            edge_weight_v1=view1.edge_attr, edge_weight_v2=view2.edge_attr)

        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        self._scheduler.step()
        self._model.update_moving_average()

        # '\r' rewrites the same console line each epoch. The lr shown is the
        # value after scheduler.step(), i.e. the one the next epoch will use.
        sys.stdout.write('\rEpoch {}/{}, loss {:.4f}, lr {}'.format(
            epoch + 1, self._args.epochs, loss.item(),
            self._optimizer.param_groups[0]['lr']))
        sys.stdout.flush()

        if (epoch + 1) % self._args.cache_step == 0:
            # Terminate the '\r' progress line before the evaluation output.
            print("")
            print("\nEvaluating {}th epoch..".format(epoch + 1))

            self.infer_embeddings()
            test_acc, _ = self.evaluate()

            self.writer.add_scalar("stats/learning_rate", self._optimizer.param_groups[0]["lr"], epoch + 1)
            self.writer.add_scalar("accs/test_acc", test_acc, epoch + 1)
            print("test: {:.4f} \n".format(test_acc))

    print()
    print("Training Done!")
| 107 | |
| 108 | def infer_embeddings(self): |
| 109 |
no test coverage detected