(self, graph: BaseGraph,
dataloader: Iterable, executor: BaseGraphExecutor, **kwargs)
| 251 | return a, b |
| 252 | |
| 253 | def optimize(self, graph: BaseGraph, |
| 254 | dataloader: Iterable, executor: BaseGraphExecutor, **kwargs) -> None: |
| 255 | splitting_layers = [] |
| 256 | for name in self.interested_layers: |
| 257 | if name not in graph.operations: |
| 258 | raise ValueError(f'Can not Split layer {name}, can not find it in current graph.') |
| 259 | |
| 260 | for operation in graph.operations.values(): |
| 261 | if operation.name in self.interested_layers: |
| 262 | assert operation.type in {'Conv', 'Gemm'}, ( |
| 263 | f'Can not split layer, cause layer type is not support') |
| 264 | splitting_layers.append(operation) |
| 265 | |
| 266 | for operation in splitting_layers: |
| 267 | assert isinstance(operation, Operation) |
| 268 | if operation.type == 'Gemm': |
| 269 | w = operation.parameters[0].value |
| 270 | w = w.transpose(0, 1) |
| 271 | if self.method == 'svd': |
| 272 | a, b = self.svd_for_factorization(w) |
| 273 | elif self.method == 'training': |
| 274 | a, b = self.train_for_factorization(w) |
| 275 | else: raise ValueError(f'Invalid method {self.method}, only support training and svd now.') |
| 276 | a = a.transpose(0, 1) |
| 277 | b = b.transpose(0, 1) |
| 278 | elif operation.type == 'Conv': |
| 279 | if operation.attributes['kernel_shape'] != [1, 1]: |
| 280 | raise PermissionError(f'Can not split layer {operation.name}, cause it kernel shape is not [1, 1]') |
| 281 | w = operation.parameters[0].value |
| 282 | assert isinstance(w, torch.Tensor) |
| 283 | w = w.squeeze(-1).squeeze(-1).transpose(0, 1) |
| 284 | print(w.shape) |
| 285 | if self.method == 'svd': |
| 286 | a, b = self.svd_for_factorization(w) |
| 287 | elif self.method == 'training': |
| 288 | a, b = self.train_for_factorization(w) |
| 289 | else: raise ValueError(f'Invalid method {self.method}, only support training and svd now.') |
| 290 | a = a.transpose(0, 1).unsqueeze(-1).unsqueeze(-1) |
| 291 | b = b.transpose(0, 1).unsqueeze(-1).unsqueeze(-1) |
| 292 | else: raise TypeError(f'Unsupported operation type {operation.type}.') |
| 293 | operation.parameters[0].value = a |
| 294 | |
| 295 | # create new operation & dirty work |
| 296 | attributes = {} |
| 297 | if operation.type == 'Conv': |
| 298 | attributes['kernel_shape'] = [1, 1] |
| 299 | attributes['pads'] = [0, 0, 0, 0] |
| 300 | attributes['strides'] = [1, 1] |
| 301 | attributes['dilations'] = [1, 1] |
| 302 | attributes['group'] = 1 |
| 303 | |
| 304 | if operation.type == 'Gemm': |
| 305 | attributes['alpha'] = 1 |
| 306 | attributes['beta'] = 1 |
| 307 | attributes['transB'] = 1 |
| 308 | |
| 309 | splitted = Operation( |
| 310 | name=operation.name + '_splited', |
nothing calls this directly
no test coverage detected