| 317 | |
| 318 | |
def lp_gather_features(pred, target, world_size=1, use_horovod=False):
    """Gather ``pred`` and ``target`` tensors from every worker.

    Args:
        pred: Local prediction tensor on this rank.
        target: Local target tensor on this rank.
        world_size: Number of ranks to gather from on the
            ``torch.distributed`` path; it sizes the receive buffers, so it
            should equal the size of the default process group. Ignored
            when ``use_horovod`` is True.
        use_horovod: If True, gather with Horovod's ``allgather``
            (requires the module-level ``hvd`` to be available); otherwise
            use ``torch.distributed.all_gather``.

    Returns:
        Tuple ``(all_preds, all_targets)`` containing every rank's tensors
        concatenated along the first dimension.
    """
    if use_horovod:
        assert hvd is not None, "Please install horovod"
        with torch.no_grad():
            all_preds = hvd.allgather(pred)
            all_targets = hvd.allgather(target)
    else:
        # Receive buffers are shaped like the local tensors, so every rank
        # must contribute tensors of identical shape.
        gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
        gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]

        dist.all_gather(gathered_preds, pred)
        dist.all_gather(gathered_targets, target)
        all_preds = torch.cat(gathered_preds, dim=0)
        all_targets = torch.cat(gathered_targets, dim=0)

    return all_preds, all_targets
| 335 | |
| 336 | |
| 337 | def get_map(pred, target): |