MCPcopy
hub / github.com/OpenPPL/ppq / write_back

Method write_back

ppq/quantization/optim/ssd.py:212–262  ·  view source on GitHub ↗
(self, pair: List[Operation], scale: torch.Tensor)

Source from the content-addressed store, hash-verified

210 return first_weight_range, last_weight_range
211
def write_back(self, pair: 'List[Operation]', scale: torch.Tensor) -> None:
    """Fold a per-channel equalization ``scale`` into a pair of operations.

    The output channels of the first operation ``pair[0]`` are multiplied by
    ``scale`` (its weight, and its bias when a second parameter is present),
    and the matching input channels of the last operation ``pair[-1]`` are
    divided by ``scale``. New tensors are assigned to
    ``parameters[i].value``; nothing is returned.

    Supported operation types and the weight layouts assumed for them:

    * ``Conv``:          ``(C_out, C_in / group, *kernel)``
    * ``ConvTranspose``: ``(C_in, C_out / group, *kernel)``
    * ``Gemm``:          ``(N, K)`` when ``transB`` is set, else ``(K, N)``

    Any number of kernel (spatial) dimensions is accepted, so 1-D and 3-D
    convolutions are handled the same way as 2-D ones.

    When the last operation is a ``Gemm`` whose input-feature count differs
    from ``scale.numel()``, its input features are treated as a flattened
    ``(C, rest)`` block with the channel axis outermost (presumably a
    flattened ``(C, H, W)`` feature map), and every feature belonging to
    channel ``c`` is divided by ``scale[c]``.

    Args:
        pair: operations to rescale; only ``pair[0]`` and ``pair[-1]`` are
            touched.
        scale: one factor per channel shared by the two operations (assumed
            1-D, since it multiplies the first operation's bias directly;
            presumably non-zero — not checked here).

    Raises:
        ValueError: if either operation's type is unsupported. This is
            checked before any parameter is modified, so the pair is never
            left half-rescaled.
        AssertionError: if either operation's first parameter is not a tensor.
    """
    first_op, last_op = pair[0], pair[-1]
    first_weight = first_op.parameters[0].value
    last_weight = last_op.parameters[0].value

    assert isinstance(first_weight, torch.Tensor)
    assert isinstance(last_weight, torch.Tensor)

    supported = ('Conv', 'Gemm', 'ConvTranspose')
    for op in (first_op, last_op):
        if op.type not in supported:
            # Rescaling only one side of the pair would silently change what
            # the network computes, so refuse before touching anything.
            raise ValueError(
                f'Can not write back equalization scale: unsupported operation '
                f'type {op.type}, expected one of {supported}.')

    def channel_view(vec: torch.Tensor, ndim: int, axis: int) -> torch.Tensor:
        # Reshape `vec` so it broadcasts along `axis` of an `ndim`-D tensor.
        shape = [1] * ndim
        shape[axis] = -1
        return vec.reshape(shape)

    # ---- first operation: scale its output channels ----
    if first_op.type == 'Conv':
        # Output channels are axis 0 of the Conv weight.
        first_op.parameters[0].value = first_weight * channel_view(scale, first_weight.dim(), 0)

    elif first_op.type == 'Gemm':
        # Output features N are axis 0 of (N, K) when transB, else axis 1 of (K, N).
        axis = 0 if first_op.attributes.get('transB', 0) else 1
        first_op.parameters[0].value = first_weight * channel_view(scale, first_weight.dim(), axis)

    else:  # ConvTranspose
        # Output channel g * (C_out / group) + j lives in group g at column j,
        # so split the C_in axis into groups to line `scale` up with it.
        num_group = first_op.attributes.get('group', 1)
        c_in, c_out_g, *kernel = first_weight.shape
        grouped = first_weight.reshape(num_group, c_in // num_group, c_out_g, *kernel)
        grouped = grouped * scale.reshape(num_group, 1, -1, *([1] * len(kernel)))
        first_op.parameters[0].value = grouped.reshape(first_weight.shape)

    if len(first_op.parameters) > 1:
        # The second parameter is treated as the bias, indexed by output channel.
        first_op.parameters[1].value = first_op.parameters[1].value * scale

    # ---- last operation: divide its input channels ----
    if last_op.type == 'Conv':
        # Input channel g * (C_in / group) + j feeds only group g at column j.
        num_group = last_op.attributes.get('group', 1)
        c_out, c_in_g, *kernel = last_weight.shape
        grouped = last_weight.reshape(num_group, c_out // num_group, c_in_g, *kernel)
        grouped = grouped / scale.reshape(num_group, 1, -1, *([1] * len(kernel)))
        last_op.parameters[0].value = grouped.reshape(last_weight.shape)

    elif last_op.type == 'Gemm':
        if last_op.attributes.get('transB', 0):
            # Weight is (N, K); input features K are axis 1.
            if scale.numel() != last_weight.shape[1]:
                grouped = last_weight.reshape(last_weight.shape[0], scale.numel(), -1)
                grouped = grouped / scale.reshape(1, -1, 1)
                last_op.parameters[0].value = grouped.reshape(last_weight.shape)
            else:
                last_op.parameters[0].value = last_weight / scale.reshape(1, -1)
        else:
            # Weight is (K, N); input features K are axis 0.
            if scale.numel() != last_weight.shape[0]:
                grouped = last_weight.reshape(scale.numel(), -1, last_weight.shape[-1])
                grouped = grouped / scale.reshape(-1, 1, 1)
                last_op.parameters[0].value = grouped.reshape(last_weight.shape)
            else:
                last_op.parameters[0].value = last_weight / scale.reshape(-1, 1)

    else:  # ConvTranspose
        # Input channels are axis 0 of the ConvTranspose weight, regardless of group.
        last_op.parameters[0].value = last_weight / channel_view(scale, last_weight.dim(), 0)
263
264 def one_step_equalization(
265 self,

Callers 1

one_step_equalization · Method · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected