子图分割算法, 这个算法将从指定节点出发, 构造一个满足定义的子图结构 Solving best block from given operation. Block definition:(子图定义) A Block is a triple contains S, E, M, where S is the input node of block where E is the output node of block where M conta
(self, op: Operation, limit: int)
| 214 | return TrainableBlock(sp=sp, ep=ep, rps=[op for _, op in rps]) |
| 215 | |
| 216 | def build(self, op: Operation, limit: int) -> TrainableBlock: |
| 217 | """子图分割算法, 这个算法将从指定节点出发, 构造一个满足定义的子图结构 Solving best block from given |
| 218 | operation. |
| 219 | |
| 220 | Block definition:(子图定义) |
| 221 | A Block is a triple contains S, E, M, |
| 222 | where S is the input node of block |
| 223 | where E is the output node of block |
| 224 | where M contains all nodes inside block |
| 225 | |
| 226 | Property:(子图性质) |
| 227 | 1. Minimal TrainableBlock start from p is {p, p, {p}}, |
| 228 | this block have only one node as both input and output. |
| 229 | 2. When S != E, |
| 230 | E must on every path from S to graph output. |
| 231 | S must on every path from graph input to E. |
| 232 | 3. M contains and only contains nodes on all paths from S to E. |
| 233 | |
| 234 | Lemma:(算法引理) |
| 235 | 1. 如果 s 的后继节点只有一个 e, 且 e 的输入只有一个,那么 {s, e, {s, e}} 构成满足定义的子图,从 s 寻找子图的任务可以递归由 e 完成 |
| 236 | 2. 如果 s 的后继存在多个节点,则只存在两种情况: |
| 237 | 2.1 从 s 出发,最大子图即为 {s, s, {s}}。 |
| 238 | 2.2 从 s 出发,可以构成子图 {s, e, R} (e!=s), 那么R中必须含有一个节点接收多个输入。(可用反证法证明,略) |
| 239 | |
| 240 | Algorithm:(算法) |
| 241 | Build(s, d): |
| 242 | 如果区块长度大于所需,则返回现有内容 |
| 243 | 从 s 出发,如果 s 的后继节点只有一个e,则判断e的输入节点个数: |
| 244 | 1. 如果 e 是单输入节点,执行Build(e, d-1),并将其结果与 {s, e, {s, e}} 合并 |
| 245 | 2. 如果 e 是多输入节点,算法立即停机,返回 {s, s, {s}} |
| 246 | |
| 247 | 如果 s 的后继节点存在多个,找出距离 s 拓扑序最近的多输入的节点 k1,判断 s 到输出的路径是否能够被 k1 阻断 |
| 248 | 如果 k 成功阻断所有输出,执行Build(k1, d-1),并将其结果与 {s, k1, F(s, k1)} 合并 |
| 249 | 如果 k 不能阻断输出,寻找距离 s 次近的多输入节点 k2,重复判断 |
| 250 | 直到 kn 到达 s 的距离超出限制 |
| 251 | |
| 252 | 函数 F(s, k1) 取出 从 s 到 k1 路上的所有节点 |
| 253 | |
| 254 | 可利用引理证明算法正确性,从略 |
| 255 | 时间复杂度: O(kd) k 为节点最大度数 d 为深度限制。建立所有Block所需时间 O(nkd) |
| 256 | """ |
| 257 | def _find_multi_input_ep(op: Operation): |
| 258 | # 如果当前节点后继节点存在多个,层序遍历寻找阻断节点 |
| 259 | least_first_queue = PriorityQueue() |
| 260 | least_first_queue.push(self.depth[op], op) |
| 261 | least_first_queue.pop() |
| 262 | |
| 263 | for down_op in self.graph.get_downstream_operations(op): |
| 264 | least_first_queue.push(self.depth[down_op], down_op) |
| 265 | |
| 266 | while not least_first_queue.empty(): |
| 267 | iter_operation = least_first_queue.pop()[-1] |
| 268 | if (least_first_queue.empty()): |
| 269 | upstream_ops = self.graph.get_upstream_operations(iter_operation) |
| 270 | if all([op in least_first_queue._ops for op in upstream_ops]) and len(upstream_ops) > 1: |
| 271 | return iter_operation |
| 272 | for down_op in self.graph.get_downstream_operations(iter_operation): |
| 273 | least_first_queue.push(self.depth[down_op], down_op) |
no test coverage detected