| 164 | |
| 165 | |
| 166 | class GraphPattern(): |
| 167 | |
| 168 | def __init__(self, node_patterns: List[Callable], edges: List[List[int]]) -> None: |
| 169 | """Pattern Tree 是一个用来表示图模式的结构体 这将在图中检索任意一个子图. |
| 170 | |
| 171 | 你将使用 Graph Pattern 定义你的子图结构 |
| 172 | 使用 patterns 确定每一个节点需要满足的条件 |
| 173 | 使用 edges 将节点们彼此相连从而构成图结构 |
| 174 | |
| 175 | 构成的图必须可以进行拓扑排序,不可以检索有环结构,不可以检索不连通结构 |
| 176 | 例子: |
| 177 | pattern = ['Conv', 'Conv', 'Conv'], |
| 178 | edges = [[0, 1], [1, 2], [0, 2]] |
| 179 | |
| 180 | 描述了一个类似这样的树形结构: |
| 181 | |
| 182 | Conv -+- Conv -+- Conv |
| 183 | | | |
| 184 | ---------- |
| 185 | |
| 186 | 第二个例子: |
| 187 | pt = PatternTree( |
| 188 | patterns = [lambda x: x.is_computing_op, 'Softplus', 'Tanh', 'Mul'] |
| 189 | edges = [[0, 1], [1, 2], [2, 3], [0, 3]]) |
| 190 | |
| 191 | pt create an abstract tree pattern of: |
| 192 | --- 'Softplus' --- 'Tanh' -- |
| 193 | lambda x: x.is_computing_op --- + + --- 'Mul' |
| 194 | --- --- --- --- -- |
| 195 | |
| 196 | 错误的例子: |
| 197 | pattern = ['Conv', 'Conv', 'Conv'], |
| 198 | edges = [[0, 1], [1, 2], [2, 0]] |
| 199 | 因为图中存在循环结构而无法检索 |
| 200 | """ |
| 201 | for idx, node_pattern in enumerate(node_patterns): |
| 202 | if isinstance(node_pattern, str): |
| 203 | node_patterns[idx] = TypeExpr(node_pattern) |
| 204 | elif not isinstance(node_pattern, Callable): |
| 205 | raise TypeError(f'Can not create Pattern with node pattern {str(node_pattern)} it is not callable.') |
| 206 | |
| 207 | for edge in edges: |
| 208 | if not isinstance(edge, tuple) and not isinstance(edge, list): |
| 209 | raise TypeError(f'Can not create Pattern with edge {str(edge)} it is not tuple or list.') |
| 210 | if len(edge) != 2: |
| 211 | raise ValueError(f'Can not create Pattern with edge {str(edge)} ' |
| 212 | f'it should contains exact 2 elements, however {len(edge)} was given.') |
| 213 | sp, ep = edge |
| 214 | if not isinstance(sp, int) or not isinstance(ep, int): |
| 215 | raise TypeError(f'Can not create Pattern was given edge {[str(sp), str(ep)]}, ' |
| 216 | 'expect int value here.') |
| 217 | |
| 218 | self.order, self.output_table, self.input_table = self.compile(node_patterns=node_patterns, edges=edges) |
| 219 | self.argsort_order = sorted([(_, idx) for idx, _ in enumerate(self.order)]) |
| 220 | self.argsort_order = [idx for _, idx in self.argsort_order] |
| 221 | self.node_patterns = node_patterns |
| 222 | |
| 223 | |