| 286 | |
| 287 | |
| 288 | class PatternMatchHelper: |
| 289 | @ staticmethod |
| 290 | def match_burte_force( |
| 291 | graph: BaseGraph, pattern: GraphPattern, |
| 292 | exclusive: bool, max_candidates: int = 1000000) -> List[List[Operation]]: |
| 293 | """暴力子图模式匹配 这是 PPQ 0.6.6 更新的内容 |
| 294 | 在 0.6.6 之前,我们使用具有不确定性的贪心匹配算法,但是考虑到实际应用中的问题 |
| 295 | 在 0.6.6 版本之后,我们将其修改为枚举匹配。 |
| 296 | |
| 297 | 子图匹配问题是一个 NP-Hard 的问题,不存在多项式时间复杂度的解法。 |
| 298 | 你需要给出一个模式子图,match_burte_force 方法将在 graph 中对模式子图进行匹配。 |
| 299 | |
| 300 | PPQ 使用了非递归的算法完成上述匹配,其最坏时间和空间复杂度大概都是 O(NM^k) |
| 301 | 其中 N 是母图节点个数,M 是子图节点个数,k 是母图的最大出度 |
| 302 | |
| 303 | 对于存在二义性子图模式,匹配复杂度将指数级增长;为了限制算法执行时间,当匹配到多于 |
| 304 | max_candidates 个模式子图时,算法强制停机,并报错返回。 |
| 305 | |
| 306 | 实际使用中的时间复杂度更加接近于 O(NM) |
| 307 | |
| 308 | 参数 exclusive 指定了是否需要进行精确匹配。在精确匹配模式下: |
| 309 | 1. 不允许模式子图中除根节点外的其他节点有来自模式子图以外节点的输入 |
| 310 | 2. 不允许模式子图中除叶节点外的其他节点有朝向模式子图以外节点的输出 |
| 311 | |
| 312 | Example: |
| 313 | pt = PatternTree( |
| 314 | patterns = [lambda x: x.is_computing_op, 'Softplus', 'Tanh', 'Mul'] |
| 315 | edges = [[0, 1], [1, 2], [2, 3], [0, 3]]) |
| 316 | |
| 317 | pt create an abstract tree pattern of: |
| 318 | --- 'Softplus' --- 'Tanh' -- |
| 319 | lambda x: x.is_computing_op --- + + --- 'Mul' |
| 320 | --- --- --- --- -- |
| 321 | |
| 322 | """ |
| 323 | |
| 324 | def is_linked(upstream_op: Operation, downstream_op: Operation) -> bool: |
| 325 | if upstream_op is None or downstream_op is None: return True |
| 326 | return downstream_op in graph.get_downstream_operations(upstream_op) |
| 327 | |
| 328 | node_order = pattern.order |
| 329 | matched_patterns = [] |
| 330 | |
| 331 | # match root from graph, further pattern matching will start from root. |
| 332 | for operation in graph.operations.values(): |
| 333 | root_idx = node_order[0] |
| 334 | if pattern.node_patterns[root_idx](operation): |
| 335 | matched_patterns.append([operation] + [None for _ in range(len(node_order) - 1)]) |
| 336 | |
| 337 | for idx in node_order[1: ]: |
| 338 | node_candidates, next_generation = [], [] |
| 339 | for operation in graph.operations.values(): |
| 340 | if pattern.node_patterns[idx](operation): |
| 341 | node_candidates.append(operation) |
| 342 | |
| 343 | for matched_pattern in matched_patterns: |
| 344 | for operation in node_candidates: |
| 345 | is_pattern_root = len(pattern.input_table[idx]) == 0 |