MCPcopy Index your code
hub / github.com/OpenPPL/ppq / GraphPattern

Class GraphPattern

ppq/IR/search.py:166–276  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

164
165
166class GraphPattern():
167
168 def __init__(self, node_patterns: List[Callable], edges: List[List[int]]) -> None:
169 """Pattern Tree 是一个用来表示图模式的结构体 这将在图中检索任意一个子图.
170
171 你将使用 Graph Pattern 定义你的子图结构
172 使用 patterns 确定每一个节点需要满足的条件
173 使用 edges 将节点们彼此相连从而构成图结构
174
175 构成的图必须可以进行拓扑排序,不可以检索有环结构,不可以检索不连通结构
176 例子:
177 pattern = ['Conv', 'Conv', 'Conv'],
178 edges = [[0, 1], [1, 2], [0, 2]]
179
180 描述了一个类似这样的树形结构:
181
182 Conv -+- Conv -+- Conv
183 | |
184 ----------
185
186 第二个例子:
187 pt = PatternTree(
188 patterns = [lambda x: x.is_computing_op, 'Softplus', 'Tanh', 'Mul']
189 edges = [[0, 1], [1, 2], [2, 3], [0, 3]])
190
191 pt create an abstract tree pattern of:
192 --- 'Softplus' --- 'Tanh' --
193 lambda x: x.is_computing_op --- + + --- 'Mul'
194 --- --- --- --- --
195
196 错误的例子:
197 pattern = ['Conv', 'Conv', 'Conv'],
198 edges = [[0, 1], [1, 2], [2, 0]]
199 因为图中存在循环结构而无法检索
200 """
201 for idx, node_pattern in enumerate(node_patterns):
202 if isinstance(node_pattern, str):
203 node_patterns[idx] = TypeExpr(node_pattern)
204 elif not isinstance(node_pattern, Callable):
205 raise TypeError(f'Can not create Pattern with node pattern {str(node_pattern)} it is not callable.')
206
207 for edge in edges:
208 if not isinstance(edge, tuple) and not isinstance(edge, list):
209 raise TypeError(f'Can not create Pattern with edge {str(edge)} it is not tuple or list.')
210 if len(edge) != 2:
211 raise ValueError(f'Can not create Pattern with edge {str(edge)} '
212 f'it should contains exact 2 elements, however {len(edge)} was given.')
213 sp, ep = edge
214 if not isinstance(sp, int) or not isinstance(ep, int):
215 raise TypeError(f'Can not create Pattern was given edge {[str(sp), str(ep)]}, '
216 'expect int value here.')
217
218 self.order, self.output_table, self.input_table = self.compile(node_patterns=node_patterns, edges=edges)
219 self.argsort_order = sorted([(_, idx) for idx, _ in enumerate(self.order)])
220 self.argsort_order = [idx for _, idx in self.argsort_order]
221 self.node_patterns = node_patterns
222
223

Callers 1

pattern_matchingMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected