Repository: github.com/dmlc/dgl

Function main

examples/sparse/hypergraphatt.py:115–135
Signature: main(args)


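```python
import torch
import tqdm

# load_data, Net, train and evaluate are defined elsewhere in hypergraphatt.py.


def main(args):
    # H: hypergraph incidence matrix, X: node features, Y: node labels.
    H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()
    model = Net(X.shape[1], num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # One training step and one evaluation per epoch; metrics are shown on the progress bar.
    with tqdm.trange(args.epochs) as tq:
        for epoch in tq:
            loss = train(model, optimizer, H, X, Y, train_mask)
            val_acc, test_acc = evaluate(
                model, H, X, Y, val_mask, test_mask, num_classes
            )
            tq.set_postfix(
                {
                    "Loss": f"{loss:.5f}",
                    "Val acc": f"{val_acc:.5f}",
                    "Test acc": f"{test_acc:.5f}",
                },
                refresh=False,
            )

    # Reports the test accuracy from the final epoch (not from the best validation epoch).
    print(f"Test acc: {test_acc:.3f}")


if __name__ == "__main__":
    ...  # argument parsing and the call to main(args) are not shown in this excerpt
```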
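The listing stops at the entry-point guard, so how `args` is built is not shown. `main` only reads `args.epochs`, so any object with an integer `epochs` attribute will do. Below is a minimal sketch using `argparse`. The `--epochs` flag name, the default of 100 and the helper name `parse_args` are assumptions, not necessarily what hypergraphatt.py defines.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical parser: the real __main__ block of hypergraphatt.py is not shown.
    parser = argparse.ArgumentParser(description="Hypergraph attention example (sketch)")
    # main() only reads args.epochs; the flag name and default are assumptions.
    parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
    return parser.parse_args(argv)


args = parse_args(["--epochs", "5"])
print(args.epochs)  # 5
```

For interactive use, `argparse.Namespace(epochs=5)` can also be passed to `main` directly, because `main` reads no other attribute.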

Callers (1)

- hypergraphatt.py (file)

Calls (5)

- parameters (method)
- load_data (function)
- Net (class)
- train (function)
- evaluate (function)
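The five calls above appear to form a standard full-batch node-classification loop: one optimizer step on the training nodes, then accuracy on the validation and test nodes, once per epoch. The sketch below mirrors `main`'s call order and signatures using hypothetical stand-ins. It uses random features, a plain MLP in place of the hypergraph attention `Net` (it ignores `H`), cross-entropy on the training mask, and accuracy on the other two masks. None of these are the example's actual implementations, and the tqdm progress bar is replaced by a list of per-epoch losses. `run` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the helpers main() calls; the real load_data,
# Net, train and evaluate in hypergraphatt.py are not reproduced here.


def load_data(num_nodes=200, num_feats=16, num_classes=3, seed=0):
    g = torch.Generator().manual_seed(seed)
    X = torch.randn(num_nodes, num_feats, generator=g)
    # Labels derived from the features so the toy task is learnable.
    Y = X[:, :num_classes].argmax(dim=1)
    H = None  # placeholder: the real example passes a hypergraph incidence matrix
    perm = torch.randperm(num_nodes, generator=g)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[perm[:120]] = True
    val_mask[perm[120:160]] = True
    test_mask[perm[160:]] = True
    return H, X, Y, num_classes, train_mask, val_mask, test_mask


class Net(nn.Module):
    # Stand-in two-layer MLP; it ignores H, unlike the real attention model.
    def __init__(self, in_size, out_size, hidden_size=32):
        super().__init__()
        self.lin1 = nn.Linear(in_size, hidden_size)
        self.lin2 = nn.Linear(hidden_size, out_size)

    def forward(self, H, X):
        return self.lin2(F.relu(self.lin1(X)))


def train(model, optimizer, H, X, Y, train_mask):
    # One full-batch optimization step; the loss uses training nodes only.
    model.train()
    logits = model(H, X)
    loss = F.cross_entropy(logits[train_mask], Y[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def evaluate(model, H, X, Y, val_mask, test_mask, num_classes):
    # Accuracy on validation and test nodes; num_classes is unused here and
    # kept only to match the call signature in main().
    model.eval()
    pred = model(H, X).argmax(dim=1)
    val_acc = (pred[val_mask] == Y[val_mask]).float().mean().item()
    test_acc = (pred[test_mask] == Y[test_mask]).float().mean().item()
    return val_acc, test_acc


def run(epochs=200):
    # Same call order as main(), with a loss list instead of a tqdm bar.
    torch.manual_seed(0)
    H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()
    model = Net(X.shape[1], num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    losses = []
    for _ in range(epochs):
        losses.append(train(model, optimizer, H, X, Y, train_mask))
        val_acc, test_acc = evaluate(
            model, H, X, Y, val_mask, test_mask, num_classes
        )
    print(f"Test acc: {test_acc:.3f}")
    return losses, val_acc, test_acc


if __name__ == "__main__":
    run()
```

Like `main`, `run` reports the test accuracy of the last epoch. Reporting the test accuracy at the epoch with the best validation accuracy would be a common alternative.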

Tested by

no test coverage detected