github.com/dmlc/dgl

Function run

examples/distributed/graphsage/node_classification.py:197–334

Train and evaluate DistSAGE.

(args, device, data)
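run() reads its hyperparameters as attributes of `args` and unpacks `data` as a six-element tuple. The sketch below shows one way to build both. Only the attribute names and the unpacking order come from the listing; every value is a hypothetical placeholder, and the graph is stubbed with `None` so the snippet runs without DGL. It also assumes one fanout per GNN layer, so `fan_out` has `num_layers` entries.

```python
import argparse

# Hypothetical hyperparameters; only the attribute names are taken from run().
args = argparse.Namespace(
    fan_out="10,25",   # one fanout per GNN layer (assumed to match num_layers)
    batch_size=1000,
    num_hidden=16,
    num_layers=2,
    dropout=0.5,
    num_gpus=0,        # 0 selects the CPU branch of the DDP wrapping
    lr=0.003,
    num_epochs=20,
)

# Placeholder data tuple, in the order run() unpacks it.
train_nid, val_nid, test_nid = [0, 1, 2], [3], [4]
in_feats, n_classes = 100, 47
g = None  # in the real script this would be a dgl.distributed.DistGraph
data = (train_nid, val_nid, test_nid, in_feats, n_classes, g)

# Same unpacking and fanout parsing as the first lines of run().
train_nid, val_nid, test_nid, in_feats, n_classes, g = data
fanouts = [int(fanout) for fanout in args.fan_out.split(",")]
print(fanouts)  # [10, 25]
```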


```python
# Module-level imports that run() relies on (declared at the top of the file).
import time

import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# DistSAGE is defined earlier in the same file.


def run(args, device, data):
    """
    Train and evaluate DistSAGE.

    Parameters
    ----------
    args : argparse.Namespace
        Arguments for train and evaluate.
    device : torch.device
        Target device for train and evaluate.
    data : tuple
        Packed data: train/val/test IDs, feature dimension,
        number of classes, graph.
    """
    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
    # One fanout per GNN layer, e.g. "10,25" -> [10, 25].
    sampler = dgl.dataloading.NeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(",")]
    )
    # Shuffled mini-batches of training node IDs; the last, smaller batch
    # is kept because drop_last=False.
    dataloader = dgl.distributed.DistNodeDataLoader(
        g,
        train_nid,
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
    )
```
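The sampler and loader hand the training loop one mini-batch of seed nodes at a time, together with the sampled neighborhoods that the model's layers consume. The pure-Python sketch below imitates that behavior on a toy in-memory graph. It is not DGL's implementation: `minibatches`, `sample_layers`, and the adjacency dictionary are hypothetical names, and each sampled layer is a sorted list of node IDs rather than a DGL block. It assumes `fanouts[i]` belongs to GNN layer `i`, so sampling starts at the seeds with the last fanout and walks back toward the input layer.

```python
import random


def minibatches(node_ids, batch_size, shuffle=True, drop_last=False, rng=random):
    """Yield lists of seed node IDs with DataLoader-like shuffle/drop_last."""
    ids = list(node_ids)
    if shuffle:
        rng.shuffle(ids)  # a fresh order on every call, i.e. every epoch
    for start in range(0, len(ids), batch_size):
        batch = ids[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the final, incomplete batch
        yield batch


def sample_layers(in_neighbors, seeds, fanouts, rng=random):
    """Return the nodes needed at each layer: input layer first, seeds last."""
    frontier = list(seeds)
    layers = [frontier]
    for fanout in reversed(fanouts):
        needed = set(frontier)  # destination nodes are kept as inputs too
        for v in frontier:
            nbrs = in_neighbors.get(v, [])
            # Uniform sampling without replacement; take all if there are too few.
            needed.update(rng.sample(nbrs, min(fanout, len(nbrs))))
        frontier = sorted(needed)
        layers.insert(0, frontier)
    return layers


rng = random.Random(0)
# Toy graph: in_neighbors[v] lists the nodes with an edge into v.
in_neighbors = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [4, 5], 4: [5], 5: [0]}
train_nid = [0, 1, 2, 3, 4]

batches = list(minibatches(train_nid, batch_size=2, rng=rng))
print([len(b) for b in batches])  # [2, 2, 1]: drop_last=False keeps the tail
layers = sample_layers(in_neighbors, batches[0], fanouts=[2, 2], rng=rng)
print([len(nodes) for nodes in layers])  # node counts, input layer first
```

Each extra layer can only grow the node set, which is why larger fanouts and deeper models make the input frontier, and the feature fetching it triggers, grow quickly.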
```python
    # ... run() continued
    model = DistSAGE(
        in_feats,
        args.num_hidden,
        n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    if args.num_gpus == 0:
        # CPU modules: DDP expects device_ids to be None.
        model = th.nn.parallel.DistributedDataParallel(model)
    else:
        model = th.nn.parallel.DistributedDataParallel(
            model, device_ids=[device], output_device=device
        )
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop.
    iter_tput = []
    epoch = 0
    epoch_time = []
    test_acc = 0.0
    for _ in range(args.num_epochs):
        epoch += 1
        tic = time.time()
        # Various time statistics.
        sample_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        # ... (excerpt ends here; the function continues to line 334)
```
Callers (1)

main (function) · 0.70

Calls (12)

parameters (method) · 0.80
append (method) · 0.80
DistSAGE (class) · 0.70
compute_acc (function) · 0.70
evaluate (function) · 0.70
to (method) · 0.45
join (method) · 0.45
long (method) · 0.45
zero_grad (method) · 0.45
backward (method) · 0.45
step (method) · 0.45
rank (method) · 0.45
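The bodies of `compute_acc` and `evaluate` are outside this excerpt. Going by its name, `compute_acc` presumably reports classification accuracy, which is conventionally the fraction of nodes whose highest-scoring class matches the label; with tensors this is typically written as `(th.argmax(pred, dim=1) == labels).float().mean()`. A dependency-free sketch of that definition, using a hypothetical `accuracy` helper:

```python
def accuracy(logits, labels):
    """Fraction of rows whose argmax class equals the label (hypothetical helper)."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == label  # argmax of the row
        for row, label in zip(logits, labels)
    )
    return correct / len(labels)


logits = [[0.1, 2.0, -1.0], [1.5, 0.3, 0.2], [0.0, 0.1, 0.9], [2.0, 0.0, 3.0]]
labels = [1, 0, 1, 2]
print(accuracy(logits, labels))  # 0.75
```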

Tested by

no test coverage detected