`split_grad_list(grad_list)` — Args: `grad_list`: K x N x 2. Returns: K x N gradients and K x N variables.
| 116 | |
| 117 | |
def split_grad_list(grad_list):
    """
    Separate nested (gradient, variable) pairs into two parallel nested lists.

    Args:
        grad_list: K x N x 2 -- an outer sequence of K "towers", each an
            iterable of N pairs whose element 0 is a gradient and element 1
            is a variable.

    Returns:
        K x N: gradients
        K x N: variables
    """
    # Walk the outer sequence exactly once (so generators are fine), building
    # a (gradients, variables) pair for each tower.
    per_tower = [
        ([pair[0] for pair in tower], [pair[1] for pair in tower])
        for tower in grad_list
    ]
    gradients = [tower_grads for tower_grads, _ in per_tower]
    variables = [tower_vars for _, tower_vars in per_tower]
    return gradients, variables
| 133 | |
| 134 | |
| 135 | def merge_grad_list(all_grads, all_vars): |