MCPcopy Index your code
hub / github.com/dmlc/dgl / train

Function train

examples/graphbolt/disk_based_feature/node_classification.py:191–235  ·  view source on GitHub ↗
(
    train_dataloader,
    valid_dataloader,
    model,
    gpu_cache_miss_rate_fn,
    cpu_cache_miss_rate_fn,
    device,
)

Source from the content-addressed store, hash-verified

189
190
def train(
    train_dataloader,
    valid_dataloader,
    model,
    gpu_cache_miss_rate_fn,
    cpu_cache_miss_rate_fn,
    device,
):
    """Run the epoch loop and return the weights of the best-validating epoch.

    Each epoch calls ``train_helper`` on ``train_dataloader`` and then
    ``evaluate`` on ``valid_dataloader``. Whenever the validation accuracy is
    strictly greater than the best seen so far, a deep copy of
    ``model.state_dict()`` is kept. A one-line summary is printed per epoch.

    The learning rate, epoch count and early-stopping patience are read from
    ``args``, which is not a parameter and must be provided by the enclosing
    module.

    Args:
        train_dataloader: Forwarded to ``train_helper``.
        valid_dataloader: Forwarded to ``evaluate``.
        model: Module whose parameters are optimized with Adam and whose
            state dict is snapshotted.
        gpu_cache_miss_rate_fn: Forwarded unchanged to both helpers.
        cpu_cache_miss_rate_fn: Forwarded unchanged to both helpers.
        device: Forwarded unchanged to both helpers.

    Returns:
        The deep-copied state dict from the epoch with the highest validation
        accuracy, or ``None`` if no epoch ran or no epoch scored above 0.
    """
    adam = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    # Both helpers take the same trailing arguments; bundle them once.
    shared_tail = (gpu_cache_miss_rate_fn, cpu_cache_miss_rate_fn, device)

    top_weights, top_acc, top_epoch = None, 0, -1

    for epoch in range(args.epochs):
        loss, approx_train_acc, elapsed = train_helper(
            train_dataloader, model, adam, criterion, *shared_tail
        )
        approx_val_acc = evaluate(model, valid_dataloader, *shared_tail)

        # Strict comparison: ties keep the earlier snapshot.
        if approx_val_acc > top_acc:
            top_acc, top_epoch = approx_val_acc, epoch
            top_weights = deepcopy(model.state_dict())

        summary = (
            f"Epoch {epoch:02d}, Loss: {loss.item():.4f}, "
            f"Approx. Train: {approx_train_acc.item():.4f}, "
            f"Approx. Val: {approx_val_acc.item():.4f}, "
            f"Time: {elapsed}s"
        )
        print(summary)

        # Early stopping: give up once more than `early_stopping_patience`
        # epochs have passed since the last improvement.
        if epoch - top_epoch > args.early_stopping_patience:
            return top_weights

    return top_weights
236
237
238@torch.no_grad()

Callers 1

main · Function · 0.70

Calls 4

parameters · Method · 0.80
state_dict · Method · 0.80
train_helper · Function · 0.70
evaluate · Function · 0.70

Tested by

no test coverage detected