MCPcopy Index your code
hub / github.com/dmlc/dgl / train

Function train

examples/pytorch/diffpool/train.py:242–323  ·  view source on GitHub ↗

training function

(dataset, model, prog_args, same_feat=True, val_dataset=None)

Source from the content-addressed store, hash-verified

240
241
242def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
243 """
244 training function
245 """
246 dir = prog_args.save_dir + "/" + prog_args.dataset
247 if not os.path.exists(dir):
248 os.makedirs(dir)
249 dataloader = dataset
250 optimizer = torch.optim.Adam(
251 filter(lambda p: p.requires_grad, model.parameters()), lr=0.001
252 )
253 early_stopping_logger = {"best_epoch": -1, "val_acc": -1}
254
255 if prog_args.cuda > 0:
256 torch.cuda.set_device(0)
257 for epoch in range(prog_args.epoch):
258 begin_time = time.time()
259 model.train()
260 accum_correct = 0
261 total = 0
262 print("\nEPOCH ###### {} ######".format(epoch))
263 computation_time = 0.0
264 for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
265 for key, value in batch_graph.ndata.items():
266 batch_graph.ndata[key] = value.float()
267 graph_labels = graph_labels.long()
268 if torch.cuda.is_available():
269 batch_graph = batch_graph.to(torch.cuda.current_device())
270 graph_labels = graph_labels.cuda()
271
272 model.zero_grad()
273 compute_start = time.time()
274 ypred = model(batch_graph)
275 indi = torch.argmax(ypred, dim=1)
276 correct = torch.sum(indi == graph_labels).item()
277 accum_correct += correct
278 total += graph_labels.size()[0]
279 loss = model.loss(ypred, graph_labels)
280 loss.backward()
281 batch_compute_time = time.time() - compute_start
282 computation_time += batch_compute_time
283 nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
284 optimizer.step()
285
286 train_accu = accum_correct / total
287 print(
288 "train accuracy for this epoch {} is {:.2f}%".format(
289 epoch, train_accu * 100
290 )
291 )
292 elapsed_time = time.time() - begin_time
293 print(
294 "loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
295 loss.item(), elapsed_time, computation_time
296 )
297 )
298 global_train_time_per_epoch.append(elapsed_time)
299 if val_dataset is not None:

Callers 1

graph_classify_task · Function · 0.70

Calls 15

filter · Function · 0.85
parameters · Method · 0.80
format · Method · 0.80
cuda · Method · 0.80
append · Method · 0.80
state_dict · Method · 0.80
evaluate · Function · 0.70
set_device · Method · 0.45
train · Method · 0.45
items · Method · 0.45
float · Method · 0.45
long · Method · 0.45

Tested by

no test coverage detected