MCPcopy
hub / github.com/THUDM/CogDL / pretrain

Function pretrain

examples/graphmae2/main_large.py:90–142  ·  view source on GitHub ↗
(model, feats, graph, ego_graph_nodes, max_epoch, device, use_scheduler, lr, weight_decay, batch_size=512, sampling_method="lc", optimizer="adam", drop_edge_rate=0)

Source from the content-addressed store, hash-verified

88
89
def _batch_loss(model, batch, epoch, device, drop_edge_rate):
    """Unpack one dataloader batch, move its graph(s) to ``device`` and return the model's loss.

    With ``drop_edge_rate > 0`` the batch is a 6-tuple
    ``(batch_g, targets, _, node_idx, drop_g1, drop_g2)``; otherwise it is a
    4-tuple ``(batch_g, targets, _, node_idx)``. The third and fourth items are
    not used here.
    """
    if drop_edge_rate > 0:
        batch_g, targets, _, _, drop_g1, drop_g2 = batch
        batch_g = batch_g.to(device)
        # The two extra graphs are forwarded to the model alongside the batch graph.
        return model(batch_g, batch_g.x, targets, epoch, drop_g1.to(device), drop_g2.to(device))

    batch_g, targets, _, _ = batch
    batch_g = batch_g.to(device)
    return model(batch_g, batch_g.x, targets, epoch)


def pretrain(model, feats, graph, ego_graph_nodes, max_epoch, device, use_scheduler, lr, weight_decay,
             batch_size=512, sampling_method="lc", optimizer="adam", drop_edge_rate=0,
             checkpoint_path=None, max_grad_norm=3):
    """Pre-train ``model`` on mini-batches produced by ``setup_training_dataloder``.

    Creates an optimizer via ``create_optimizer`` and a dataloader via
    ``setup_training_dataloder``, then runs ``max_epoch`` epochs. Each batch goes
    through forward, backward, gradient-norm clipping and an optimizer step.
    After every epoch the model's ``state_dict`` is saved to ``checkpoint_path``,
    overwriting the previous epoch's file. The mean batch loss is then printed.

    Args:
        model: module whose call ``model(batch_g, x, targets, epoch[, drop_g1, drop_g2])``
            returns a single-element loss tensor (``.item()`` is taken on it).
        feats, graph, ego_graph_nodes: passed through to ``setup_training_dataloder``.
        max_epoch: number of epochs; ``0`` skips training entirely.
        device: device the model and every batch graph are moved to.
        use_scheduler: if true and ``max_epoch > 0``, a ``LambdaLR`` is used. It
            multiplies the learning rate by ``(1 + cos(pi * epoch / max_epoch)) / 2``
            and is stepped once per epoch.
        lr, weight_decay: passed to ``create_optimizer``.
        batch_size, sampling_method, drop_edge_rate: passed to the dataloader.
            When ``drop_edge_rate > 0`` each batch also carries two extra graphs
            that are forwarded to the model.
        optimizer: optimizer name given to ``create_optimizer``. The name is
            rebound to the created optimizer object.
        checkpoint_path: file the state dict is written to after each epoch.
            ``None`` (default) keeps the original behaviour of using
            ``os.path.join(model_dir, model_name)``, where ``model_dir`` and
            ``model_name`` are module-level globals. Presumably they are set by
            the script's ``__main__`` block; verify against the caller.
        max_grad_norm: max total gradient norm for ``clip_grad_norm_``
            (default 3, the previously hard-coded value).

    Returns:
        The trained model (the object returned by ``model.to(device)``).

    Raises:
        NameError: if ``checkpoint_path`` is ``None``, ``max_epoch > 0`` and the
            ``model_dir`` / ``model_name`` globals are not defined.
    """
    logging.info("start training..")

    model = model.to(device)
    optimizer = create_optimizer(optimizer, model, lr, weight_decay)

    dataloader = setup_training_dataloder(
        sampling_method, ego_graph_nodes, graph, feats, batch_size=batch_size, drop_edge_rate=drop_edge_rate)

    logging.info(f"After creating dataloader: Memory: {show_occupied_memory():.2f} MB")
    if use_scheduler and max_epoch > 0:
        logging.info("Use scheduler")

        def cosine_factor(epoch):
            # Cosine decay of the LR multiplier from 1 toward 0 over max_epoch epochs.
            return (1 + np.cos(epoch * np.pi / max_epoch)) * 0.5

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_factor)
    else:
        scheduler = None

    if checkpoint_path is None and max_epoch > 0:
        # Backward-compatible fallback on the module globals. Resolve it before
        # training so a missing global fails immediately, not after the first epoch.
        checkpoint_path = os.path.join(model_dir, model_name)

    for epoch in range(max_epoch):
        epoch_iter = tqdm(dataloader)
        losses = []

        for batch in epoch_iter:
            # Called per batch, as before, so training mode is re-asserted even if
            # something switched the model to eval mode in between.
            model.train()
            loss = _batch_loss(model, batch, epoch, device, drop_edge_rate)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            # Read the scalar once: every .item() is a device -> host synchronization.
            loss_value = loss.item()
            epoch_iter.set_description(f"train_loss: {loss_value:.4f}, Memory: {show_occupied_memory():.2f} MB")
            losses.append(loss_value)

        if scheduler is not None:
            scheduler.step()

        torch.save(model.state_dict(), checkpoint_path)

        # An empty dataloader would make np.mean warn ("Mean of empty slice"); report nan explicitly.
        mean_loss = np.mean(losses) if losses else float("nan")
        print(f"# Epoch {epoch} | train_loss: {mean_loss:.4f}, Memory: {show_occupied_memory():.2f} MB")

    return model
143
144
145if __name__ == "__main__":

Callers 1

main_large.py · File · 0.70

Calls 8

create_optimizer · Function · 0.90

setup_training_dataloder · Function · 0.90

show_occupied_memory · Function · 0.90

to · Method · 0.45

train · Method · 0.45

zero_grad · Method · 0.45

backward · Method · 0.45

step · Method · 0.45

Tested by

no test coverage detected