Train and evaluate DistSAGE. Parameters: args (argparse.Args) - arguments for train and evaluate; device (torch.Device) - target device for train and evaluate; data (Packed Data) - packed data that includes train/val/test IDs, feature dimension, number of classes, and graph.
(args, device, data)
| 195 | |
| 196 | |
| 197 | def run(args, device, data): |
| 198 | """ |
| 199 | Train and evaluate DistSAGE. |
| 200 | |
| 201 | Parameters |
| 202 | ---------- |
| 203 | args : argparse.Args |
| 204 | Arguments for train and evaluate. |
| 205 | device : torch.Device |
| 206 | Target device for train and evaluate. |
| 207 | data : Packed Data |
| 208 | Packed data includes train/val/test IDs, feature dimension, |
| 209 | number of classes, graph. |
| 210 | """ |
| 211 | train_nid, val_nid, test_nid, in_feats, n_classes, g = data |
| 212 | sampler = dgl.dataloading.NeighborSampler( |
| 213 | [int(fanout) for fanout in args.fan_out.split(",")] |
| 214 | ) |
| 215 | dataloader = dgl.distributed.DistNodeDataLoader( |
| 216 | g, |
| 217 | train_nid, |
| 218 | sampler, |
| 219 | batch_size=args.batch_size, |
| 220 | shuffle=True, |
| 221 | drop_last=False, |
| 222 | ) |
| 223 | model = DistSAGE( |
| 224 | in_feats, |
| 225 | args.num_hidden, |
| 226 | n_classes, |
| 227 | args.num_layers, |
| 228 | F.relu, |
| 229 | args.dropout, |
| 230 | ) |
| 231 | model = model.to(device) |
| 232 | if args.num_gpus == 0: |
| 233 | model = th.nn.parallel.DistributedDataParallel(model) |
| 234 | else: |
| 235 | model = th.nn.parallel.DistributedDataParallel( |
| 236 | model, device_ids=[device], output_device=device |
| 237 | ) |
| 238 | loss_fcn = nn.CrossEntropyLoss() |
| 239 | loss_fcn = loss_fcn.to(device) |
| 240 | optimizer = optim.Adam(model.parameters(), lr=args.lr) |
| 241 | |
| 242 | # Training loop. |
| 243 | iter_tput = [] |
| 244 | epoch = 0 |
| 245 | epoch_time = [] |
| 246 | test_acc = 0.0 |
| 247 | for _ in range(args.num_epochs): |
| 248 | epoch += 1 |
| 249 | tic = time.time() |
| 250 | # Various time statistics. |
| 251 | sample_time = 0 |
| 252 | forward_time = 0 |
| 253 | backward_time = 0 |
| 254 | update_time = 0 |
no test coverage detected