r""" The core training pipeline Args: loggers: List of loggers loaders: List of loaders model: GNN model optimizer: PyTorch optimizer scheduler: PyTorch learning rate scheduler
(loggers, loaders, model, optimizer, scheduler)
| 47 | |
| 48 | |
def train(loggers, loaders, model, optimizer, scheduler):
    r"""
    The core training pipeline.

    Every epoch trains on the first split. On evaluation epochs (as decided
    by ``is_eval_epoch``) it also evaluates every remaining split. On
    checkpoint epochs (as decided by ``is_ckpt_epoch``) it saves a
    checkpoint. When training finishes, all loggers are closed. Checkpoints
    are cleaned up if ``cfg.train.ckpt_clean`` is set.

    Args:
        loggers: List of loggers, one per split. ``loggers[0]`` records the
            training split; ``loggers[1:]`` record the evaluation splits.
        loaders: List of loaders, aligned index-by-index with ``loggers``
            (must contain at least ``len(loggers)`` entries).
        model: GNN model
        optimizer: PyTorch optimizer
        scheduler: PyTorch learning rate scheduler

    """
    start_epoch = 0
    if cfg.train.auto_resume:
        # NOTE(review): load_ckpt presumably returns the next epoch to run
        # (0 when no checkpoint exists) -- confirm against its definition.
        start_epoch = load_ckpt(model, optimizer, scheduler)
    # ``>=`` rather than ``==``: a checkpoint at or beyond max_epoch (e.g.
    # after the configured max_epoch was lowered) also means no epochs are
    # left; the training loop below is empty in both cases.
    if start_epoch >= cfg.optim.max_epoch:
        logging.info('Checkpoint found, Task already done')
    else:
        logging.info('Start from epoch %s', start_epoch)

    num_splits = len(loggers)
    for cur_epoch in range(start_epoch, cfg.optim.max_epoch):
        # Split 0 is the training split.
        train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)
        loggers[0].write_epoch(cur_epoch)
        if is_eval_epoch(cur_epoch):
            # Remaining splits are only evaluated, never trained on.
            for i in range(1, num_splits):
                eval_epoch(loggers[i], loaders[i], model)
                loggers[i].write_epoch(cur_epoch)
        if is_ckpt_epoch(cur_epoch):
            save_ckpt(model, optimizer, scheduler, cur_epoch)

    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    logging.info('Task done, results saved in %s', cfg.out_dir)
no test coverage detected