Repository: github.com/snap-stanford/GraphGym

Function train

graphgym/train.py:49–84

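The core training pipeline. If cfg.train.auto_resume is set, train resumes from a saved checkpoint. It then trains on split 0 every epoch, evaluates splits 1 onward on evaluation epochs, and saves checkpoints on checkpoint epochs. Finally it closes all loggers and deletes checkpoints if cfg.train.ckpt_clean is set.

Args:
    loggers: list of loggers, one per data split (index 0 is the training split)
    loaders: list of data loaders, in the same split order as loggers
    model: GNN model
    optimizer: PyTorch optimizer
    scheduler: PyTorch learning rate scheduler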

(loggers, loaders, model, optimizer, scheduler)
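train relies on duck typing for its list arguments. Each logger must provide write_epoch(epoch) and close(). loaders[i] is read alongside loggers[i], and index 0 is treated as the training split. The sketch below writes that contract down as a typing.Protocol plus a pre-flight check. The names TrainLogger, check_splits and PrintLogger are hypothetical illustrations and are not part of GraphGym.

```python
from typing import Iterable, Protocol, Sequence, runtime_checkable


@runtime_checkable
class TrainLogger(Protocol):
    """Hypothetical protocol: the two logger methods train() calls."""

    def write_epoch(self, cur_epoch: int) -> None: ...

    def close(self) -> None: ...


def check_splits(loggers: Sequence[object], loaders: Sequence[Iterable]) -> int:
    """Hypothetical pre-flight check mirroring how train() indexes its inputs.

    train() sets num_splits = len(loggers) and reads loaders[i] for every
    i < num_splits, so there must be at least as many loaders as loggers.
    Split 0 is the training split; splits 1.. are evaluation splits.
    """
    if not loggers:
        raise ValueError("train() needs at least one logger (the training split)")
    if len(loaders) < len(loggers):
        raise ValueError(f"{len(loggers)} loggers but only {len(loaders)} loaders")
    for i, logger in enumerate(loggers):
        # runtime_checkable only checks that the methods exist, not their signatures
        if not isinstance(logger, TrainLogger):
            raise TypeError(f"loggers[{i}] lacks write_epoch()/close()")
    return len(loggers)


class PrintLogger:
    """Toy logger that satisfies TrainLogger by recording epochs."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.epochs: list[int] = []

    def write_epoch(self, cur_epoch: int) -> None:
        self.epochs.append(cur_epoch)

    def close(self) -> None:
        print(f"{self.name}: wrote epochs {self.epochs}")
```

A typical three-split setup passes [train, val, test] loggers and loaders in the same order. Only the first pair is used for training.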


# Module-level imports train() depends on (module paths assumed from the
# GraphGym repository layout; train_epoch and eval_epoch are defined
# earlier in graphgym/train.py itself)
import logging

from graphgym.checkpoint import clean_ckpt, load_ckpt, save_ckpt
from graphgym.config import cfg
from graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch


def train(loggers, loaders, model, optimizer, scheduler):
    r"""
    The core training pipeline

    Args:
        loggers: List of loggers
        loaders: List of loaders
        model: GNN model
        optimizer: PyTorch optimizer
        scheduler: PyTorch learning rate scheduler

    """
    start_epoch = 0
    # Optionally resume: load_ckpt restores state and returns the epoch to start from
    if cfg.train.auto_resume:
        start_epoch = load_ckpt(model, optimizer, scheduler)
    if start_epoch == cfg.optim.max_epoch:
        logging.info('Checkpoint found, Task already done')
    else:
        logging.info('Start from epoch {}'.format(start_epoch))

    # Split 0 is trained every epoch; splits 1.. are evaluated on eval epochs
    num_splits = len(loggers)
    for cur_epoch in range(start_epoch, cfg.optim.max_epoch):
        train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)
        loggers[0].write_epoch(cur_epoch)
        if is_eval_epoch(cur_epoch):
            for i in range(1, num_splits):
                eval_epoch(loggers[i], loaders[i], model)
                loggers[i].write_epoch(cur_epoch)
        if is_ckpt_epoch(cur_epoch):
            save_ckpt(model, optimizer, scheduler, cur_epoch)
    # Close every logger, then optionally delete the saved checkpoints
    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    logging.info('Task done, results saved in {}'.format(cfg.out_dir))
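train reads cfg and calls GraphGym helpers, so it cannot run in isolation. The sketch below replays the same control flow with the checkpoint and schedule helpers replaced by injected callables. It records the order of calls instead of training anything. Hooks and trace_train are hypothetical names, and the plain max_epoch, auto_resume and ckpt_clean arguments stand in for cfg.optim.max_epoch, cfg.train.auto_resume and cfg.train.ckpt_clean.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Hooks:
    """Hypothetical stand-ins for the GraphGym helpers train() calls."""
    load_ckpt: Callable[[], int]          # returns the epoch to resume from
    is_eval_epoch: Callable[[int], bool]
    is_ckpt_epoch: Callable[[int], bool]


def trace_train(num_splits: int, max_epoch: int, hooks: Hooks,
                auto_resume: bool = False, ckpt_clean: bool = False) -> List[str]:
    """Replays train()'s control flow and returns the sequence of calls."""
    events: List[str] = []
    start_epoch = 0
    if auto_resume:
        start_epoch = hooks.load_ckpt()
    if start_epoch == max_epoch:
        events.append("log: checkpoint found, task already done")
    else:
        events.append(f"log: start from epoch {start_epoch}")

    # When start_epoch == max_epoch this range is empty: nothing left to train
    for cur_epoch in range(start_epoch, max_epoch):
        events.append(f"train_epoch(split 0) @ {cur_epoch}")
        events.append(f"write_epoch(split 0) @ {cur_epoch}")
        if hooks.is_eval_epoch(cur_epoch):
            for i in range(1, num_splits):
                events.append(f"eval_epoch(split {i}) @ {cur_epoch}")
                events.append(f"write_epoch(split {i}) @ {cur_epoch}")
        if hooks.is_ckpt_epoch(cur_epoch):
            events.append(f"save_ckpt @ {cur_epoch}")

    # Loggers are closed whether or not any epoch ran
    events.extend(f"close(split {i})" for i in range(num_splits))
    if ckpt_clean:
        events.append("clean_ckpt")
    return events
```

In a resumed run, load_ckpt may return max_epoch. In that case train skips the loop entirely but still closes every logger, and it still cleans checkpoints if ckpt_clean is set.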

Callers (1)

main.py · File · 0.90

Calls (9)

load_ckpt · Function · 0.90
is_eval_epoch · Function · 0.90
is_ckpt_epoch · Function · 0.90
save_ckpt · Function · 0.90
clean_ckpt · Function · 0.90
write_epoch · Method · 0.80
close · Method · 0.80
train_epoch · Function · 0.70
eval_epoch · Function · 0.70
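is_eval_epoch and is_ckpt_epoch decide which epochs run evaluation and checkpointing, but their bodies are not shown here. In the listing they take only cur_epoch, so they presumably read their settings from cfg. The sketch below assumes a periodic schedule. Evaluation happens on the first epoch, every eval_period epochs, and the final epoch. Checkpointing happens every ckpt_period epochs and on the final epoch. The names should_eval, should_ckpt, eval_period and ckpt_period are hypothetical, and the rules are an assumption, so check GraphGym's own definitions before relying on them.

```python
def should_eval(cur_epoch: int, eval_period: int, max_epoch: int) -> bool:
    """Assumed rule: evaluate on the first epoch, every eval_period epochs,
    and on the final epoch (epochs are 0-indexed, as in train())."""
    return (
        cur_epoch == 0
        or (cur_epoch + 1) % eval_period == 0
        or cur_epoch + 1 == max_epoch
    )


def should_ckpt(cur_epoch: int, ckpt_period: int, max_epoch: int) -> bool:
    """Assumed rule: checkpoint every ckpt_period epochs and on the final epoch."""
    return (cur_epoch + 1) % ckpt_period == 0 or cur_epoch + 1 == max_epoch


if __name__ == "__main__":
    max_epoch = 10
    print("eval :", [e for e in range(max_epoch) if should_eval(e, 4, max_epoch)])
    # eval : [0, 3, 7, 9]
    print("ckpt :", [e for e in range(max_epoch) if should_ckpt(e, 5, max_epoch)])
    # ckpt : [4, 9]
```

Under this assumed rule, the last epoch always both evaluates and checkpoints. That way the final metrics and a resumable state exist even when max_epoch is not a multiple of either period.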

Tested by

no test coverage detected