| 135 | |
| 136 | |
def h_split(self, op: 'Operation') -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Split the second input value of ``op`` into two additive parts.

    Elements of ``op.inputs[1].value`` whose absolute value is greater than
    ``self.value_threshold`` are divided between a "split" tensor and a
    "remainder" tensor; all other elements stay entirely in the remainder.
    The remainder is computed as ``value - s_value``, so the two returned
    tensors add back up to the original value (up to floating point rounding).

    Split strategy, selected by ``self.method``:
        * ``'balance'``: the split part is half of each selected element.
        * ``'random'``: the split part is each selected element scaled by a
          factor drawn from ``torch.rand_like``.

    Args:
        op: Operation whose ``inputs[1].value`` is split (the original
            comment calls it the weight). ``op.name`` is only used for the
            verbose report.

    Returns:
        A tuple ``(r_value, s_value, processed_values)``: the remainder
        tensor, the split tensor, and the number of elements above the
        threshold.

    Raises:
        ValueError: If ``self.method`` is neither ``'balance'`` nor
            ``'random'``.
    """
    # split weight
    value = op.inputs[1].value
    mask = value.abs() > self.value_threshold
    processed_values = mask.sum().item()

    # Multiplying by the boolean mask zeroes the split part for every
    # element that is not above the threshold.
    if self.method == 'balance':
        s_value = (value / 2) * mask
    elif self.method == 'random':
        s_value = (value * torch.rand_like(value)) * mask
    else:
        # ValueError is still an Exception, so existing handlers keep working,
        # but the message now says what was wrong and what is accepted.
        raise ValueError(
            f'Unsupported split method {self.method!r}; '
            "expected 'balance' or 'random'.")
    r_value = value - s_value

    # print
    if self.verbose:
        print('')
        print(f'# Layer {op.name} has been splited, '
              f'{processed_values}/{value.numel()} value(s) was processed.')
    return r_value, s_value, processed_values
| 157 | |
| 158 | def optimize(self, graph: BaseGraph, |
| 159 | dataloader: Iterable, executor: BaseGraphExecutor, |