Function train(model, optimizer, loader)
Source from the content-addressed store, hash-verified
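```python
import torch.nn.functional as F

# `device` and `num_graphs` are assumed to be defined at module level,
# earlier in this file (outside this excerpt).


def train(model, optimizer, loader):
    model.train()  # training mode: enables dropout, batch-norm statistic updates, etc.

    total_loss = 0
    for data in loader:
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        data = data.to(device)
        out = model(data)  # nll_loss expects log-probabilities (e.g. log_softmax output)
        loss = F.nll_loss(out, data.y.view(-1))  # mean NLL over the batch (default reduction)
        loss.backward()
        # Weight the batch-mean loss by the number of graphs in the batch so the
        # value returned below is a per-graph mean, even when the last batch is smaller.
        total_loss += loss.item() * num_graphs(data)
        optimizer.step()
    return total_loss / len(loader.dataset)


# Next function in the module (body not shown in this excerpt):
# def eval_acc(model, loader):
```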
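The only non-obvious bookkeeping in `train` is the epoch loss. `F.nll_loss` returns the batch mean by default, so multiplying by `num_graphs(data)` recovers the batch sum, and dividing by `len(loader.dataset)` yields a per-graph mean over the epoch, assuming the loader visits every graph exactly once (for example, no `drop_last`). The pure-Python sketch that follows uses made-up batch losses and sizes (hypothetical numbers, not measured results) to show how this differs from naively averaging batch means when the final batch is small; `epoch_loss` is a hypothetical helper that mirrors the accumulation in `train`.

```python
# Hypothetical per-batch results: (mean NLL of the batch, graphs in the batch).
# In train() these correspond to loss.item() and num_graphs(data).
batches = [(0.9, 32), (0.7, 32), (2.0, 4)]


def epoch_loss(batches, dataset_size):
    """Mirror train()'s bookkeeping: weight each batch mean by its graph count."""
    total_loss = 0.0
    for batch_mean, n_graphs in batches:
        total_loss += batch_mean * n_graphs  # batch mean -> batch sum
    return total_loss / dataset_size  # per-graph mean over the whole epoch


dataset_size = sum(n for _, n in batches)  # 68 graphs, each seen once

weighted = epoch_loss(batches, dataset_size)
naive = sum(m for m, _ in batches) / len(batches)  # treats the 4-graph batch like a full one

print(f"{weighted:.4f}")  # 0.8706
print(f"{naive:.4f}")     # 1.2000
```

The naive average over-weights the small final batch; the weighted version counts every graph equally, which is what dividing by `len(loader.dataset)` is meant to achieve.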
Tested by: no test coverage detected