Merge responses from multiple servers into a single graph
(res_list, num_nodes, exclude_edges=None)
| 690 | |
| 691 | |
def merge_graphs(res_list, num_nodes, exclude_edges=None):
    """Merge responses from multiple servers into a single graph.

    Parameters
    ----------
    res_list : list
        Per-server results. Every item must expose ``global_src``,
        ``global_dst``, ``global_eids`` and ``etype_ids`` attributes;
        ``global_eids`` and ``etype_ids`` may be ``None``. Must be non-empty.
    num_nodes : int
        Node count of the resulting graph, forwarded to ``graph``.
    exclude_edges : tensor, optional
        Edge IDs to remove from the merged result. Filtering is done on the
        merged ``global_eids``, so the results must carry edge IDs when this
        is given.

    Returns
    -------
    The graph built by ``graph((src, dst), num_nodes=num_nodes)``. When edge
    IDs / edge-type IDs are available, they are stored in ``g.edata[EID]`` /
    ``g.edata[ETYPE]``.
    """
    if len(res_list) > 1:
        src_tensor = F.cat([res.global_src for res in res_list], 0)
        dst_tensor = F.cat([res.global_dst for res in res_list], 0)
        # Only the first result is inspected for the optional fields: all
        # servers are expected to either all return them or all omit them.
        if res_list[0].global_eids is None:
            eid_tensor = None
        else:
            eid_tensor = F.cat([res.global_eids for res in res_list], 0)
        if res_list[0].etype_ids is None:
            etype_id_tensor = None
        else:
            etype_id_tensor = F.cat([res.etype_ids for res in res_list], 0)
    else:
        # Single response: no concatenation needed.
        only = res_list[0]
        src_tensor = only.global_src
        dst_tensor = only.global_dst
        eid_tensor = only.global_eids
        etype_id_tensor = only.etype_ids

    if exclude_edges is not None:
        # Keep edges whose ID is NOT in exclude_edges. assume_unique must not
        # be set: nothing here deduplicates the merged edge IDs or
        # exclude_edges, and torch.isin's uniqueness shortcut can report
        # wrong membership when either input contains duplicates.
        mask = torch.isin(eid_tensor, exclude_edges, invert=True)
        src_tensor = src_tensor[mask]
        dst_tensor = dst_tensor[mask]
        eid_tensor = eid_tensor[mask]
        if etype_id_tensor is not None:
            etype_id_tensor = etype_id_tensor[mask]

    g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
    if eid_tensor is not None:
        g.edata[EID] = eid_tensor
    if etype_id_tensor is not None:
        g.edata[ETYPE] = etype_id_tensor
    return g
| 728 | |
| 729 | |
| 730 | LocalSampledGraph = namedtuple( # pylint: disable=unexpected-keyword-arg |
no test coverage detected