Training function
(dataset, model, prog_args, same_feat=True, val_dataset=None)
| 240 | |
| 241 | |
| 242 | def train(dataset, model, prog_args, same_feat=True, val_dataset=None): |
| 243 | """ |
| 244 | training function |
| 245 | """ |
| 246 | dir = prog_args.save_dir + "/" + prog_args.dataset |
| 247 | if not os.path.exists(dir): |
| 248 | os.makedirs(dir) |
| 249 | dataloader = dataset |
| 250 | optimizer = torch.optim.Adam( |
| 251 | filter(lambda p: p.requires_grad, model.parameters()), lr=0.001 |
| 252 | ) |
| 253 | early_stopping_logger = {"best_epoch": -1, "val_acc": -1} |
| 254 | |
| 255 | if prog_args.cuda > 0: |
| 256 | torch.cuda.set_device(0) |
| 257 | for epoch in range(prog_args.epoch): |
| 258 | begin_time = time.time() |
| 259 | model.train() |
| 260 | accum_correct = 0 |
| 261 | total = 0 |
| 262 | print("\nEPOCH ###### {} ######".format(epoch)) |
| 263 | computation_time = 0.0 |
| 264 | for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader): |
| 265 | for key, value in batch_graph.ndata.items(): |
| 266 | batch_graph.ndata[key] = value.float() |
| 267 | graph_labels = graph_labels.long() |
| 268 | if torch.cuda.is_available(): |
| 269 | batch_graph = batch_graph.to(torch.cuda.current_device()) |
| 270 | graph_labels = graph_labels.cuda() |
| 271 | |
| 272 | model.zero_grad() |
| 273 | compute_start = time.time() |
| 274 | ypred = model(batch_graph) |
| 275 | indi = torch.argmax(ypred, dim=1) |
| 276 | correct = torch.sum(indi == graph_labels).item() |
| 277 | accum_correct += correct |
| 278 | total += graph_labels.size()[0] |
| 279 | loss = model.loss(ypred, graph_labels) |
| 280 | loss.backward() |
| 281 | batch_compute_time = time.time() - compute_start |
| 282 | computation_time += batch_compute_time |
| 283 | nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip) |
| 284 | optimizer.step() |
| 285 | |
| 286 | train_accu = accum_correct / total |
| 287 | print( |
| 288 | "train accuracy for this epoch {} is {:.2f}%".format( |
| 289 | epoch, train_accu * 100 |
| 290 | ) |
| 291 | ) |
| 292 | elapsed_time = time.time() - begin_time |
| 293 | print( |
| 294 | "loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format( |
| 295 | loss.item(), elapsed_time, computation_time |
| 296 | ) |
| 297 | ) |
| 298 | global_train_time_per_epoch.append(elapsed_time) |
| 299 | if val_dataset is not None: |
No test coverage detected.