| 140 | |
| 141 | |
def pretrain(model, dataloaders, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None):
    """Run the pre-training loop for ``max_epoch`` epochs and return the model.

    Each epoch iterates over the training loader, calls ``model(subgraph, subgraph.x)``
    (which must return ``(loss, loss_dict)``), and takes one optimizer step per batch.
    The optional scheduler is stepped once per epoch. Halfway through training
    (at ``epoch == max_epoch // 2``) a muted evaluation is run.

    Args:
        model: Module whose call returns ``(loss, loss_dict)`` for a subgraph batch.
        dataloaders: 4-tuple ``(train_loader, val_loader, test_loader, eval_train_loader)``.
        optimizer: Optimizer updating ``model``'s parameters.
        max_epoch: Number of pre-training epochs.
        device: Target device that batches are moved to via ``.to(device)``.
        scheduler: Optional LR scheduler (``None`` to disable); stepped once per epoch.
        num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob: Forwarded
            unchanged to the mid-training evaluation call.
        logger: Optional object with ``note(dict, step=...)``. When given, it receives
            the last batch's ``loss_dict`` plus the current ``"lr"`` each epoch.

    Returns:
        The same ``model`` object, trained in place.
    """
    logging.info("start training..")
    train_loader, val_loader, test_loader, eval_train_loader = dataloaders

    epoch_iter = tqdm(range(max_epoch))

    # A single-graph "loader" (a one-element list) is moved to `device` once,
    # before the epoch loop, and reused for evaluation on the training graph.
    if isinstance(train_loader, list) and len(train_loader) == 1:
        train_loader = [train_loader[0].to(device)]
        eval_train_loader = train_loader
    # Likewise for a single validation graph. That same graph also serves as the
    # test loader (presumably it carries both val/test split masks; confirm).
    if isinstance(val_loader, list) and len(val_loader) == 1:
        val_loader = [val_loader[0].to(device)]
        test_loader = val_loader

    for epoch in epoch_iter:
        model.train()
        loss_list = []
        # Reset every epoch: an epoch with no batches must neither hit an unbound
        # name nor re-log the previous epoch's dict.
        loss_dict = {}

        for subgraph in train_loader:
            subgraph = subgraph.to(device)
            loss, loss_dict = model(subgraph, subgraph.x)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())

        if scheduler is not None:
            scheduler.step()

        # np.mean([]) warns and returns nan; make the empty-epoch case explicit.
        train_loss = np.mean(loss_list) if loss_list else float("nan")
        epoch_iter.set_description(f"# Epoch {epoch} | train_loss: {train_loss:.4f}")
        if logger is not None:
            loss_dict["lr"] = get_current_lr(optimizer)
            logger.note(loss_dict, step=epoch)

        if epoch == (max_epoch // 2):
            # Muted midpoint evaluation; its result is not used here.
            # `evaluete` (sic) is called as spelled; it is defined outside this view.
            evaluete(model, (eval_train_loader, val_loader, test_loader), num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True)
    return model
| 180 | |
| 181 | |
| 182 | def main(args): |