(self, pair: List[Operation], scale: torch.Tensor)
| 210 | return first_weight_range, last_weight_range |
| 211 | |
def write_back(self, pair: List[Operation], scale: torch.Tensor) -> None:
    """Fold an equalization scale into the weights at both ends of ``pair``.

    The weight of ``pair[0]`` is multiplied by ``scale`` along its
    output-channel axis and the weight of ``pair[-1]`` is divided by
    ``scale`` along its input-channel axis. Both weights are read before
    anything is written, and every result is stored as a new tensor in the
    corresponding ``parameters[i].value``.

    Conv and ConvTranspose weights may carry any number of trailing kernel
    dimensions (1-D, 2-D, 3-D, ...); the broadcast shape of ``scale`` is
    derived from the weight itself instead of assuming a 4-D tensor.

    Args:
        pair: Operations to rewrite. Only ``pair[0]`` and ``pair[-1]`` are
            touched; each must expose ``type``, ``attributes`` and
            ``parameters`` with the weight at index 0.
        scale: Tensor holding one factor per channel shared by the two
            operations; it is reshaped here to broadcast along that axis.

    Note:
        Operation types other than ``Conv``, ``Gemm`` and ``ConvTranspose``
        keep their weight unchanged, but a second parameter on ``pair[0]``
        (presumably the bias - confirm against the graph builder) is still
        multiplied by ``scale``.
    """
    first_op, last_op = pair[0], pair[-1]
    first_weight = first_op.parameters[0].value
    last_weight = last_op.parameters[0].value

    assert isinstance(first_weight, torch.Tensor)
    assert isinstance(last_weight, torch.Tensor)

    # ---- upstream op: scale the output channels --------------------------
    if first_op.type == 'Conv':
        # Weight layout: (C_out, C_in / group, *kernel).
        trailing = (1,) * (first_weight.dim() - 1)
        first_op.parameters[0].value = first_weight * scale.reshape(-1, *trailing)

    elif first_op.type == 'Gemm':
        if first_op.attributes.get('transB', 0):
            # transB: output features run along axis 0.
            first_op.parameters[0].value = first_weight * scale.reshape(-1, 1)
        else:
            first_op.parameters[0].value = first_weight * scale.reshape(1, -1)

    elif first_op.type == 'ConvTranspose':
        # Weight layout: (C_in, C_out / group, *kernel). Output channel
        # g * (C_out / group) + j lives in group g, so expose the group axis
        # before broadcasting the scale.
        num_group = first_op.attributes.get('group', 1)
        c_in, c_out_g, *kernel = first_weight.shape
        grouped = first_weight.reshape(num_group, c_in // num_group, c_out_g, *kernel)
        grouped = grouped * scale.reshape(num_group, 1, -1, *((1,) * len(kernel)))
        first_op.parameters[0].value = grouped.reshape(first_weight.shape)

    if len(first_op.parameters) > 1:
        # The second parameter is scaled whatever the op type is.
        first_op.parameters[1].value = first_op.parameters[1].value * scale

    # ---- downstream op: undo the scale on the input channels -------------
    if last_op.type == 'Conv':
        # Weight layout: (C_out, C_in / group, *kernel). Input channel
        # g * (C_in / group) + j lives in group g.
        num_group = last_op.attributes.get('group', 1)
        c_out, c_in_g, *kernel = last_weight.shape
        grouped = last_weight.reshape(num_group, c_out // num_group, c_in_g, *kernel)
        grouped = grouped / scale.reshape(num_group, 1, -1, *((1,) * len(kernel)))
        last_op.parameters[0].value = grouped.reshape(last_weight.shape)

    elif last_op.type == 'Gemm':
        # When the input-feature axis is longer than ``scale``, each channel
        # owns a contiguous run of entries on that axis (presumably because
        # a flatten sits between the two ops - confirm against the caller).
        if last_op.attributes.get('transB', 0):
            if scale.numel() != last_weight.shape[1]:
                rows = last_weight.shape[0]
                blocked = last_weight.reshape(rows, scale.numel(), -1)
                blocked = blocked / scale.reshape(1, -1, 1)
                last_op.parameters[0].value = blocked.reshape(rows, -1)
            else:
                last_op.parameters[0].value = last_weight / scale.reshape(1, -1)
        else:
            if scale.numel() != last_weight.shape[0]:
                cols = last_weight.shape[-1]
                blocked = last_weight.reshape(scale.numel(), -1, cols)
                blocked = blocked / scale.reshape(-1, 1, 1)
                last_op.parameters[0].value = blocked.reshape(-1, cols)
            else:
                last_op.parameters[0].value = last_weight / scale.reshape(-1, 1)

    elif last_op.type == 'ConvTranspose':
        # Weight layout: (C_in, C_out / group, *kernel); input channels are
        # the leading axis, so no group handling is needed.
        trailing = (1,) * (last_weight.dim() - 1)
        last_op.parameters[0].value = last_weight / scale.reshape(-1, *trailing)
| 263 | |
| 264 | def one_step_equalization( |
| 265 | self, |
no outgoing calls
no test coverage detected