MCPcopy
hub / github.com/OpenPPL/ppq / build

Method build

ppq/quantization/algorithm/training.py:216–300  ·  view source on GitHub ↗

子图分割算法, 这个算法将从指定节点出发, 构造一个满足定义的子图结构 Solving best block from given operation. Block definition:(子图定义) A Block is a triple contains S, E, M, where S is the input node of block where E is the output node of block where M conta

(self, op: Operation, limit: int)

Source from the content-addressed store, hash-verified

214 return TrainableBlock(sp=sp, ep=ep, rps=[op for _, op in rps])
215
216 def build(self, op: Operation, limit: int) -> TrainableBlock:
217 """子图分割算法, 这个算法将从指定节点出发, 构造一个满足定义的子图结构 Solving best block from given
218 operation.
219
220 Block definition:(子图定义)
221 A Block is a triple contains S, E, M,
222 where S is the input node of block
223 where E is the output node of block
224 where M contains all nodes inside block
225
226 Property:(子图性质)
227 1. Minimal TrainableBlock start from p is {p, p, {p}},
228 this block have only one node as both input and output.
229 2. When S != E,
230 E must on every path from S to graph output.
231 S must on every path from graph input to E.
232 3. M contains and only contains nodes on all paths from S to E.
233
234 Lemma:(算法引理)
235 1. 如果 s 的后继节点只有一个 e, 且 e 的输入只有一个,那么 {s, e, {s, e}} 构成满足定义的子图,从 s 寻找子图的任务可以递归由 e 完成
236 2. 如果 s 的后继存在多个节点,则只存在两种情况:
237 2.1 从 s 出发,最大子图即为 {s, s, {s}}。
238 2.2 从 s 出发,可以构成子图 {s, e, R} (e!=s), 那么R中必须含有一个节点接收多个输入。(可用反证法证明,略)
239
240 Algorithm:(算法)
241 Build(s, d):
242 如果区块长度大于所需,则返回现有内容
243 从 s 出发,如果 s 的后继节点只有一个e,则判断e的输入节点个数:
244 1. 如果 e 是单输入节点,执行Build(e, d-1),并将其结果与 {s, e, {s, e}} 合并
245 2. 如果 e 是多输入节点,算法立即停机,返回 {s, s, {s}}
246
247 如果 s 的后继节点存在多个,找出距离 s 拓扑序最近的多输入的节点 k1,判断 s 到输出的路径是否能够被 k1 阻断
248 如果 k 成功阻断所有输出,执行Build(k1, d-1),并将其结果与 {s, k1, F(s, k1)} 合并
249 如果 k 不能阻断输出,寻找距离 s 次近的多输入节点 k2,重复判断
250 直到 kn 到达 s 的距离超出限制
251
252 函数 F(s, k1) 取出 从 s 到 k1 路上的所有节点
253
254 可利用引理证明算法正确性,从略
255 时间复杂度: O(kd) k 为节点最大度数 d 为深度限制。建立所有Block所需时间 O(nkd)
256 """
257 def _find_multi_input_ep(op: Operation):
258 # 如果当前节点后继节点存在多个,层序遍历寻找阻断节点
259 least_first_queue = PriorityQueue()
260 least_first_queue.push(self.depth[op], op)
261 least_first_queue.pop()
262
263 for down_op in self.graph.get_downstream_operations(op):
264 least_first_queue.push(self.depth[down_op], down_op)
265
266 while not least_first_queue.empty():
267 iter_operation = least_first_queue.pop()[-1]
268 if (least_first_queue.empty()):
269 upstream_ops = self.graph.get_upstream_operations(iter_operation)
270 if all([op in least_first_queue._ops for op in upstream_ops]) and len(upstream_ops) > 1:
271 return iter_operation
272 for down_op in self.graph.get_downstream_operations(iter_operation):
273 least_first_queue.push(self.depth[down_op], down_op)

Callers 2

optimizeMethod · 0.95

Calls 2

create_blockMethod · 0.95

Tested by

no test coverage detected