github.com/dmlc/dgl

Function train

examples/sparse/sampling/ladies.py:167–197
(device, A, ndata, dataset, model)


```python
import torch
import torch.nn.functional as F

# multilayer_sample() and evaluate() are helper functions defined in ladies.py.


def train(device, A, ndata, dataset, model):
    # Create dataloaders over the train/validation node IDs; the layer-wise
    # sampling itself happens per batch in multilayer_sample().
    train_idx = dataset.train_idx.to(device)
    val_idx = dataset.val_idx.to(device)

    train_dataloader = torch.utils.data.DataLoader(
        train_idx, batch_size=1024, shuffle=True
    )
    val_dataloader = torch.utils.data.DataLoader(val_idx, batch_size=1024)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    # Per-layer sample sizes passed to multilayer_sample() (three layers).
    fanouts = [4000, 4000, 4000]
    for epoch in range(20):
        model.train()
        total_loss = 0
        for it, seeds in enumerate(train_dataloader):
            # Sample the per-layer matrices for this batch of seed nodes,
            # along with the input features x and the seed labels y.
            sampled_matrices, x, y = multilayer_sample(A, fanouts, seeds, ndata)
            y_hat = model(sampled_matrices, x)
            loss = F.cross_entropy(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Validate once per epoch; the reported loss is the mean over batches.
        acc = evaluate(model, A, val_dataloader, ndata, dataset.num_classes)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, total_loss / (it + 1), acc.item()
            )
        )
```
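In `train`, each epoch iterates over shuffled mini-batches of 1024 training node IDs and reports `total_loss / (it + 1)`, which is the unweighted mean of the per-batch losses: a smaller final batch counts as much as a full one. The stdlib-only sketch below reproduces just that bookkeeping, assuming a plain list of ints stands in for `dataset.train_idx` and a callback stands in for the model's loss; `iterate_minibatches` and `run_epoch` are hypothetical helpers, not PyTorch or DGL APIs.

```python
import random
from typing import Callable, Iterator, List, Sequence, Tuple


def iterate_minibatches(
    indices: Sequence[int], batch_size: int, shuffle: bool, rng: random.Random
) -> Iterator[List[int]]:
    """Yield consecutive chunks of the (optionally shuffled) indices.

    The last chunk may be smaller, as with a DataLoader whose drop_last is False.
    """
    order = list(indices)
    if shuffle:
        rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


def run_epoch(
    indices: Sequence[int],
    batch_size: int,
    batch_loss: Callable[[List[int]], float],
    rng: random.Random,
) -> Tuple[float, int]:
    """Return (mean of per-batch losses, number of batches) for one epoch."""
    total_loss = 0.0
    num_batches = 0
    for seeds in iterate_minibatches(indices, batch_size, True, rng):
        total_loss += batch_loss(seeds)
        num_batches += 1
    # Equivalent to total_loss / (it + 1) in train(); max() guards an empty epoch.
    return total_loss / max(num_batches, 1), num_batches


if __name__ == "__main__":
    # Hypothetical loss: 1.0 for full batches, 0.0 for the short last batch.
    def fake_loss(seeds: List[int]) -> float:
        return 1.0 if len(seeds) == 1024 else 0.0

    avg, n = run_epoch(list(range(2500)), 1024, fake_loss, random.Random(0))
    print(n, avg)  # -> 3 0.6666666666666666
```

With 2,500 training IDs and a batch size of 1024, an epoch has three iterations (1024, 1024 and 452 IDs). Weighting each batch loss by `len(seeds)` and dividing by the number of IDs would give a per-sample average instead; the code in `train` does not do that.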

Callers 1

ladies.py · File · 0.70

Calls 9

parameters · Method · 0.80
format · Method · 0.80
multilayer_sample · Function · 0.70
evaluate · Function · 0.70
to · Method · 0.45
train · Method · 0.45
zero_grad · Method · 0.45
backward · Method · 0.45
step · Method · 0.45
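`multilayer_sample` is listed among the calls, but its body is not shown here. As a rough picture of the layer-dependent importance sampling that LADIES is named after, the sketch below performs one layer step in plain Python: columns are drawn with probability proportional to their squared norm over the seed rows, and kept entries are rescaled so that the sampled matrix is an unbiased estimate of the seed rows of A. The dict-of-rows matrix format, sampling with replacement, and the helper name `ladies_layer_sample` are all assumptions for illustration; this is not the code in ladies.py, which may differ in how it samples and normalizes.

```python
import random
from collections import Counter
from typing import Dict, List, Tuple

# Hypothetical sparse format: {row: {col: value}}.
SparseRows = Dict[int, Dict[int, float]]


def ladies_layer_sample(
    A: SparseRows, seeds: List[int], fanout: int, rng: random.Random
) -> Tuple[SparseRows, List[int]]:
    """One layer-dependent importance-sampling step (illustrative sketch).

    Returns the sampled, rescaled rows for `seeds` and the sorted IDs of the
    sampled columns (candidate seeds for the next layer).
    """
    # Importance weight of column j: sum over seed rows i of A[i][j] ** 2.
    weight: Dict[int, float] = {}
    for i in seeds:
        for j, a_ij in A.get(i, {}).items():
            weight[j] = weight.get(j, 0.0) + a_ij * a_ij
    total = sum(weight.values())
    if total == 0.0:
        return {i: {} for i in seeds}, []
    cols = list(weight)
    probs = [weight[j] / total for j in cols]
    p = dict(zip(cols, probs))

    # Draw `fanout` columns i.i.d. with replacement and count repeats.
    counts = Counter(rng.choices(cols, weights=probs, k=fanout))

    # Keep only sampled columns, scaled by count / (fanout * p_j); since
    # E[count_j] = fanout * p_j, each kept entry equals A[i][j] in expectation.
    sampled: SparseRows = {}
    for i in seeds:
        row = A.get(i, {})
        sampled[i] = {
            j: row[j] * c / (fanout * p[j]) for j, c in counts.items() if j in row
        }
    return sampled, sorted(counts)
```

Calling it once per entry of `fanouts`, with the returned column IDs as the next layer's seeds, would give a multi-layer sampler in the same spirit; the sketch omits any row renormalization a real implementation may apply.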

Tested by

no test coverage detected