hub / github.com/dmlc/dgl / run

Function run

examples/distributed/graphsage/node_classification_unsupervised.py:186–369
(args, device, data)

Source from the content-addressed store, hash-verified

def run(args, device, data):
    # Unpack data
    (
        train_eids,
        train_nids,
        in_feats,
        g,
        global_train_nid,
        global_valid_nid,
        global_test_nid,
        labels,
    ) = data
    # Create sampler
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(args.num_negs)
    sampler = dgl.dataloading.NeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(",")]
    )
    # Create dataloader
    exclude = "reverse_id" if args.remove_edge else None
    reverse_eids = th.arange(g.num_edges()) if args.remove_edge else None
    dataloader = dgl.distributed.DistEdgeDataLoader(
        g,
        train_eids,
        sampler,
        negative_sampler=neg_sampler,
        exclude=exclude,
        reverse_eids=reverse_eids,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
    )
    # Define model and optimizer
    model = DistSAGE(
        in_feats,
        args.num_hidden,
        args.num_hidden,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            dev_id = g.rank() % args.num_gpus
            model = th.nn.parallel.DistributedDataParallel(
                model, device_ids=[dev_id], output_device=dev_id
            )
    loss_fcn = CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    epoch = 0
    for epoch in range(args.num_epochs):
        num_seeds = 0
        num_inputs = 0
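
The setup steps in the listing above can be tried without DGL or PyTorch installed. The following is a minimal sketch in plain Python, not the DGL implementation: it mirrors how the comma-separated `args.fan_out` string becomes per-layer fanouts, how `g.rank() % args.num_gpus` wraps trainer ranks onto local GPUs, and roughly what drawing a fixed number of uniform negative destinations per source node looks like. The helper names `parse_fan_out`, `pick_device_id`, and `uniform_negatives` are hypothetical and exist only for illustration, as are the sample values.

```python
import random


def parse_fan_out(fan_out):
    """Turn a comma-separated string such as "10,25" into per-layer fanouts."""
    return [int(fanout) for fanout in fan_out.split(",")]


def pick_device_id(rank, num_gpus):
    """Map a trainer rank to a local GPU index by wrapping around num_gpus."""
    return rank % num_gpus


def uniform_negatives(src_nodes, num_nodes, k, rng):
    """Pair every source node with k destinations drawn uniformly at random.

    Illustration only: a drawn destination may happen to be a true neighbor.
    """
    return [(u, rng.randrange(num_nodes)) for u in src_nodes for _ in range(k)]


# Hypothetical values standing in for args.fan_out, args.num_gpus, args.num_negs.
fanouts = parse_fan_out("10,25")
device_ids = [pick_device_id(rank, 4) for rank in range(6)]
negatives = uniform_negatives([0, 1], num_nodes=100, k=3, rng=random.Random(0))

print(fanouts)
print(device_ids)
print(len(negatives))
```

With four GPUs per machine, ranks 4 and 5 wrap back to devices 0 and 1, which is why the modulo is applied before the model is wrapped in `DistributedDataParallel`.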

Callers 1

main · Function · 0.70

Calls 15

generate_emb · Function · 0.85
parameters · Method · 0.80
append · Method · 0.80
format · Method · 0.80
DistSAGE · Class · 0.70
CrossEntropyLoss · Class · 0.70
load_subtensor · Function · 0.70
compute_acc · Function · 0.70
num_edges · Method · 0.45
to · Method · 0.45
rank · Method · 0.45
join · Method · 0.45

Tested by

no test coverage detected