(self, root)
| 199 | self._trees.append(self._build_tree(sent)) |
| 200 | |
def _build_tree(self, root):
    """Convert one parsed sentence tree into a graph.

    The tree is walked depth-first (pre-order, children left to right).
    Each tree node becomes a graph node whose integer id is its visit
    order, so the root is always node 0. Every non-root node gets exactly
    one edge, pointing from itself to its parent (child -> parent).

    Node attributes:
        x    -- for leaf-level nodes (whose first element is a str/bytes
                token): ``self.vocab.get(token.lower(), self.UNK_WORD)``;
                for every other node (including the root):
                ``SSTDataset.PAD_WORD``.
        y    -- ``int(node.label())``.
        mask -- 1 for leaf-level nodes, 0 otherwise.

    Args:
        root: Tree object that exposes ``label()`` and iterates over its
            children. Presumably an ``nltk.Tree``; TODO confirm against
            the caller.

    Returns:
        The result of ``from_networkx`` applied to the built
        ``nx.DiGraph``, carrying node attributes ``x``, ``y`` and ``mask``.
    """
    g = nx.DiGraph()

    def _rec_build(nid, node):
        # Add every child of `node` (graph id `nid`), then recurse into
        # the non-leaf children.
        for child in node:
            # Ids are assigned densely in visit order.
            cid = g.number_of_nodes()
            if isinstance(child[0], (str, bytes)):
                # Leaf-level node: its first element is the token itself.
                # NOTE(review): a bytes token is lowercased and looked up
                # as bytes, so it only matches if vocab has bytes keys.
                # Confirm the intended key type.
                word = self.vocab.get(child[0].lower(), self.UNK_WORD)
                g.add_node(cid, x=word, y=int(child.label()), mask=1)
            else:
                g.add_node(
                    cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0
                )
                _rec_build(cid, child)
            # Edges point from child to parent.
            g.add_edge(cid, nid)

    # The root is inserted first so that it always receives id 0.
    g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
    _rec_build(0, root)
    return from_networkx(g, node_attrs=["x", "y", "mask"])
| 223 | |
| 224 | @property |
| 225 | def graph_path(self): |
no test coverage detected