MCPcopy
hub / github.com/dmlc/dgl / merge_graphs

Function merge_graphs

python/dgl/distributed/graph_services.py:692–727  ·  view source on GitHub ↗

Merge the results returned by multiple servers into a single graph.

(res_list, num_nodes, exclude_edges=None)

Source from the content-addressed store, hash-verified

690
691
def merge_graphs(res_list, num_nodes, exclude_edges=None):
    """Merge the results returned by one or more servers into a single graph.

    Parameters
    ----------
    res_list : list
        Non-empty list of per-server results. Each entry must expose the
        attributes ``global_src``, ``global_dst``, ``global_eids`` and
        ``etype_ids``. ``global_eids`` and ``etype_ids`` may be ``None``.
        Only the first entry is inspected to decide whether they are
        ``None``, so all entries are expected to agree on that.
    num_nodes : int
        Number of nodes of the resulting graph (passed to ``graph``).
    exclude_edges : tensor or array-like, optional
        Edge IDs to drop from the result. They are matched against the
        merged ``global_eids`` values. Non-tensor inputs (e.g. lists) are
        converted with ``torch.as_tensor``.

    Returns
    -------
    DGLGraph
        Graph built from the remaining ``(src, dst)`` pairs. ``edata[EID]``
        is set when edge IDs are available, and ``edata[ETYPE]`` when
        edge-type IDs are available.

    Raises
    ------
    ValueError
        If ``exclude_edges`` is given but the results carry no edge IDs
        (``global_eids`` is ``None``) to match against.
    """
    if len(res_list) > 1:
        src_tensor = F.cat([res.global_src for res in res_list], 0)
        dst_tensor = F.cat([res.global_dst for res in res_list], 0)
        eids = [res.global_eids for res in res_list]
        etype_ids = [res.etype_ids for res in res_list]
        eid_tensor = None if eids[0] is None else F.cat(eids, 0)
        etype_id_tensor = None if etype_ids[0] is None else F.cat(etype_ids, 0)
    else:
        # A single result needs no concatenation; an empty res_list raises
        # IndexError here.
        src_tensor = res_list[0].global_src
        dst_tensor = res_list[0].global_dst
        eid_tensor = res_list[0].global_eids
        etype_id_tensor = res_list[0].etype_ids

    if exclude_edges is not None:
        if eid_tensor is None:
            raise ValueError(
                "exclude_edges was given but the sampled results carry no "
                "edge IDs (global_eids is None) to match against."
            )
        # Bring the exclusion set onto the same device as the edge IDs.
        # This is a no-op for a tensor that is already there.
        exclude_edges = torch.as_tensor(exclude_edges, device=eid_tensor.device)
        # Do NOT pass assume_unique=True. Nothing guarantees the sampled edge
        # IDs are unique (sampling with replacement can repeat an edge), and
        # torch.isin gives wrong answers for duplicated inputs in that mode.
        keep = torch.isin(eid_tensor, exclude_edges, invert=True)
        src_tensor = src_tensor[keep]
        dst_tensor = dst_tensor[keep]
        eid_tensor = eid_tensor[keep]
        if etype_id_tensor is not None:
            etype_id_tensor = etype_id_tensor[keep]

    g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
    if eid_tensor is not None:
        g.edata[EID] = eid_tensor
    if etype_id_tensor is not None:
        g.edata[ETYPE] = etype_id_tensor
    return g
728
729
730LocalSampledGraph = namedtuple( # pylint: disable=unexpected-keyword-arg

Callers 1

_distributed_access · Function · 0.85

Calls 2

graph · Function · 0.85
append · Method · 0.80

Tested by

no test coverage detected