Repository: github.com/dmlc/dgl

Function train

Source: examples/sparse/appnp.py, lines 56–84
Signature: train(model, g, A_hat, X)
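The signature takes a precomputed `A_hat` instead of deriving it from `g`. In APPNP this matrix is typically the symmetrically normalized adjacency with self-loops, D̃^(-1/2) (A + I) D̃^(-1/2), and predictions are propagated with a personalized-PageRank iteration. The dense sketch below assumes that construction. `normalized_adjacency`, `appnp_propagate`, and the defaults `alpha=0.1` and `k=10` are hypothetical names and illustrative values, not code from this page. The example itself lives under `examples/sparse` and presumably builds `A_hat` with DGL's sparse matrices.

```python
import torch


def normalized_adjacency(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: D^{-1/2} (A + I) D^{-1/2} for a dense symmetric adjacency A."""
    A_tilde = A + torch.eye(A.shape[0], dtype=A.dtype)  # add self-loops
    deg = A_tilde.sum(dim=1)  # degrees including the self-loop, so always >= 1
    d_inv_sqrt = deg.pow(-0.5)
    # Scale row i by d_i^{-1/2} and column j by d_j^{-1/2}.
    return d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]


def appnp_propagate(A_hat: torch.Tensor, H: torch.Tensor, alpha: float = 0.1, k: int = 10) -> torch.Tensor:
    """Hypothetical helper: k steps of Z <- (1 - alpha) * A_hat @ Z + alpha * H."""
    Z = H
    for _ in range(k):
        # Mix the propagated signal with the "teleport" term alpha * H.
        Z = (1 - alpha) * (A_hat @ Z) + alpha * H
    return Z


# Toy 3-node path graph 0-1-2 with 2-dimensional node features.
A = torch.tensor([[0.0, 1.0, 0.0],
                  [1.0, 0.0, 1.0],
                  [0.0, 1.0, 0.0]])
X = torch.randn(3, 2)
A_hat = normalized_adjacency(A)
Z = appnp_propagate(A_hat, X)
print(A_hat)
print(Z.shape)  # torch.Size([3, 2])
```

Because `A_hat` depends only on the graph, computing it once outside the training loop and passing it in avoids renormalizing the adjacency on every epoch.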


```python
import torch.nn.functional as F
from torch.optim import Adam

# `evaluate(g, pred)` is a helper from the same example; it is not shown on this page.


def train(model, g, A_hat, X):
    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    optimizer = Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    for epoch in range(50):
        # Forward.
        model.train()
        logits = model(A_hat, X)

        # Compute loss with nodes in training set.
        loss = F.cross_entropy(logits[train_mask], label[train_mask])

        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute prediction.
        model.eval()
        logits = model(A_hat, X)
        pred = logits.argmax(dim=1)

        # Evaluate the prediction.
        val_acc, test_acc = evaluate(g, pred)
        print(
            f"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test"
            f" acc: {test_acc:.3f}"
        )
```
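`train` relies on an `evaluate(g, pred)` helper whose body is not shown here. A minimal hypothetical stand-in, assuming `g.ndata` also carries boolean `val_mask` and `test_mask` tensors alongside `label` (as citation datasets such as Cora typically do), applies the same masking pattern the loss uses. It is a sketch, not the example's actual implementation, and a `SimpleNamespace` stands in for a DGL graph so that it runs without DGL installed.

```python
from types import SimpleNamespace

import torch


def evaluate(g, pred):
    """Hypothetical stand-in: accuracy of `pred` on the validation and test nodes."""
    label = g.ndata["label"]
    val_mask = g.ndata["val_mask"]  # assumed boolean mask over nodes
    test_mask = g.ndata["test_mask"]  # assumed boolean mask over nodes
    # Boolean indexing keeps only the masked nodes, exactly like logits[train_mask].
    val_acc = (pred[val_mask] == label[val_mask]).float().mean().item()
    test_acc = (pred[test_mask] == label[test_mask]).float().mean().item()
    return val_acc, test_acc


# Toy 4-node "graph": nodes 0-1 are validation nodes, nodes 2-3 are test nodes.
g = SimpleNamespace(ndata={
    "label": torch.tensor([0, 1, 1, 0]),
    "val_mask": torch.tensor([True, True, False, False]),
    "test_mask": torch.tensor([False, False, True, True]),
})
pred = torch.tensor([0, 0, 1, 0])
val_acc, test_acc = evaluate(g, pred)
print(val_acc, test_acc)  # 0.5 1.0
```

Returning plain Python floats via `.item()` keeps the `:.3f` formatting in `train`'s `print` call straightforward.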

Callers 1

appnp.py · File · 0.70

Calls 6

parameters · Method · 0.80
evaluate · Function · 0.70
train · Method · 0.45
zero_grad · Method · 0.45
backward · Method · 0.45
step · Method · 0.45

Tested by

no test coverage detected