Initialize internal states.
(self, gidx, ntypes, etypes, node_frames, edge_frames)
| 124 | self._init(gidx, ntypes, etypes, node_frames, edge_frames) |
| 125 | |
| 126 | def _init(self, gidx, ntypes, etypes, node_frames, edge_frames): |
| 127 | """Init internal states.""" |
| 128 | self._graph = gidx |
| 129 | self._canonical_etypes = None |
| 130 | self._batch_num_nodes = None |
| 131 | self._batch_num_edges = None |
| 132 | |
| 133 | # Handle node types |
| 134 | if isinstance(ntypes, tuple): |
| 135 | if len(ntypes) != 2: |
| 136 | errmsg = "Invalid input. Expect a pair (srctypes, dsttypes) but got {}".format( |
| 137 | ntypes |
| 138 | ) |
| 139 | raise TypeError(errmsg) |
| 140 | if not self._graph.is_metagraph_unibipartite(): |
| 141 | raise ValueError( |
| 142 | "Invalid input. The metagraph must be a uni-directional" |
| 143 | " bipartite graph." |
| 144 | ) |
| 145 | self._ntypes = ntypes[0] + ntypes[1] |
| 146 | self._srctypes_invmap = {t: i for i, t in enumerate(ntypes[0])} |
| 147 | self._dsttypes_invmap = { |
| 148 | t: i + len(ntypes[0]) for i, t in enumerate(ntypes[1]) |
| 149 | } |
| 150 | self._is_unibipartite = True |
| 151 | if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1: |
| 152 | self._canonical_etypes = [ |
| 153 | (ntypes[0][0], etypes[0], ntypes[1][0]) |
| 154 | ] |
| 155 | else: |
| 156 | self._ntypes = ntypes |
| 157 | if len(ntypes) == 1: |
| 158 | src_dst_map = None |
| 159 | else: |
| 160 | src_dst_map = find_src_dst_ntypes( |
| 161 | self._ntypes, self._graph.metagraph |
| 162 | ) |
| 163 | self._is_unibipartite = src_dst_map is not None |
| 164 | if self._is_unibipartite: |
| 165 | self._srctypes_invmap, self._dsttypes_invmap = src_dst_map |
| 166 | else: |
| 167 | self._srctypes_invmap = { |
| 168 | t: i for i, t in enumerate(self._ntypes) |
| 169 | } |
| 170 | self._dsttypes_invmap = self._srctypes_invmap |
| 171 | |
| 172 | # Handle edge types |
| 173 | self._etypes = etypes |
| 174 | if self._canonical_etypes is None: |
| 175 | if len(etypes) == 1 and len(ntypes) == 1: |
| 176 | self._canonical_etypes = [(ntypes[0], etypes[0], ntypes[0])] |
| 177 | else: |
| 178 | self._canonical_etypes = make_canonical_etypes( |
| 179 | self._etypes, self._ntypes, self._graph.metagraph |
| 180 | ) |
| 181 | |
| 182 | # An internal map from etype to canonical etype tuple. |
| 183 | # If two etypes have the same name, an empty tuple is stored instead to indicate |
no test coverage detected