MCPcopy
hub / github.com/THUDM/CogDL / pretrain

Function pretrain

examples/graphmae/main_inductive.py:142–179  ·  view source on GitHub ↗
(model, dataloaders, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None)

Source from the content-addressed store, hash-verified

140
141
def pretrain(model, dataloaders, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None):
    """Run the self-supervised pretraining loop and return the trained model.

    For each of ``max_epoch`` epochs, every batch from ``train_loader`` is
    moved to ``device`` and passed to ``model(graph, graph.x)``. That call
    must return ``(loss, loss_dict)``. The optimizer is stepped once per
    batch and the scheduler (if any) once per epoch. At epoch
    ``max_epoch // 2`` a muted call to ``evaluete`` is made as an
    intermediate check; its return value is discarded.

    Args:
        model: Module whose call returns ``(loss, loss_dict)``. ``loss`` must
            support ``.backward()`` and ``.item()``. ``loss_dict`` must be a
            mutable mapping, because an ``"lr"`` entry is written into it.
        dataloaders: 4-tuple
            ``(train_loader, val_loader, test_loader, eval_train_loader)``.
        optimizer: Optimizer stepped once per training batch.
        max_epoch: Number of pretraining epochs.
        device: Target passed to ``.to()`` on every graph.
        scheduler: Optional scheduler stepped once per epoch; ``None`` disables it.
        num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob:
            Forwarded unchanged to ``evaluete``.
        logger: Optional object with a ``note(dict, step=...)`` method.

    Returns:
        The same ``model`` object, updated in place.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    logging.info("start training..")
    train_loader, val_loader, test_loader, eval_train_loader = dataloaders

    epoch_iter = tqdm(range(max_epoch))

    # A one-element list (presumably a single full graph -- confirm with the
    # caller) is moved to the device once, up front. The evaluation splits are
    # then aliased to it: eval_train reuses the train graph, and test reuses
    # the val graph.
    if isinstance(train_loader, list) and len(train_loader) == 1:
        train_loader = [train_loader[0].to(device)]
        eval_train_loader = train_loader
    if isinstance(val_loader, list) and len(val_loader) == 1:
        val_loader = [val_loader[0].to(device)]
        test_loader = val_loader

    for epoch in epoch_iter:
        model.train()
        loss_list = []

        for subgraph in train_loader:
            subgraph = subgraph.to(device)
            loss, loss_dict = model(subgraph, subgraph.x)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())

        # With no batches there is no loss to average and no loss_dict to log.
        # Previously this gave a NaN mean or a NameError; fail explicitly instead.
        if not loss_list:
            raise ValueError("train_loader yielded no batches; nothing to pretrain on")

        if scheduler is not None:
            scheduler.step()

        train_loss = np.mean(loss_list)
        epoch_iter.set_description(f"# Epoch {epoch} | train_loss: {train_loss:.4f}")
        if logger is not None:
            # loss_dict holds the values from the last batch of this epoch only.
            loss_dict["lr"] = get_current_lr(optimizer)
            logger.note(loss_dict, step=epoch)

        # Muted midpoint evaluation. The result is not used. The next epoch
        # calls model.train() again before training resumes.
        if epoch == (max_epoch // 2):
            evaluete(model, (eval_train_loader, val_loader, test_loader), num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True)
    return model
180
181
182def main(args):

Callers 1

main — Function · 0.70

Calls 8

get_current_lr — Function · 0.90
evaluete — Function · 0.85
to — Method · 0.45
train — Method · 0.45
zero_grad — Method · 0.45
backward — Method · 0.45
step — Method · 0.45
note — Method · 0.45

Tested by

no test coverage detected