Equalization step with the scale calculated in the way specified by algo_type. Args: pair (List[Operation]): a list of operations representing an equalization pair. op_act_channel_range (Dict[Operation, torch.Tensor]): channel-wise activation range of all Conv ops in the graph.
(
self,
pair: List[Operation],
op_act_channel_range: Dict[Operation, torch.Tensor]={},
algo_type: int=2,
ssd_min_scale: float=8,
ssd_max_scale: float=2,
dfq_min_scale: float=0.1,
dfq_max_scale: float=10,
eps: float=1e-8
)
| 262 | pair[-1].parameters[0].value = last_computing_op_weight / scale.reshape(-1, 1, 1, 1) |
| 263 | |
| 264 | def one_step_equalization( |
| 265 | self, |
| 266 | pair: List[Operation], |
| 267 | op_act_channel_range: Dict[Operation, torch.Tensor]={}, |
| 268 | algo_type: int=2, |
| 269 | ssd_min_scale: float=8, |
| 270 | ssd_max_scale: float=2, |
| 271 | dfq_min_scale: float=0.1, |
| 272 | dfq_max_scale: float=10, |
| 273 | eps: float=1e-8 |
| 274 | ): |
| 275 | """Equalization step with scale being calculated in the way specified |
| 276 | by algo_type. |
| 277 | |
| 278 | Args: |
| 279 | pair (List[Operation]): a list of operations representing a equalzation pair |
| 280 | op_act_channel_range (Dict[Operation, torch.Tensor]): channel-wise activation range of all Conv ops in the graph |
| 281 | algo_type (int, optional): minor algo type. 0 represents dfq algo, 1~3 represents ssd algo. Defaults to 2. |
| 282 | ssd_min_scale (float, optional): minimum clip value of scale for ssd algo. Defaults to 8. |
| 283 | ssd_max_scale (float, optional): maximum clip value of scale for ssd algo. Defaults to 2. |
| 284 | dfq_min_scale (float, optional): minimum clip value of scale for dfq algo. Defaults to 0.1. |
| 285 | dfq_max_scale (float, optional): maximum clip value of scale for dfq algo. Defaults to 10. |
| 286 | eps (float, optional): small constant for numerical stability. Defaults to 1e-8. |
| 287 | """ |
| 288 | first_weight_range, last_weight_range = self.prepare_weight_for_equalization(pair) |
| 289 | |
| 290 | if algo_type == 0: |
| 291 | # dfq |
| 292 | scale = torch.sqrt(last_weight_range / (first_weight_range + eps)) |
| 293 | scale = torch.clamp(scale, dfq_min_scale, dfq_max_scale) |
| 294 | |
| 295 | else: |
| 296 | first_weight_range = torch.where(first_weight_range < first_weight_range.max() * self.channel_ratio,\ |
| 297 | first_weight_range.max() * self.channel_ratio, first_weight_range) |
| 298 | last_weight_range = torch.where(last_weight_range < last_weight_range.max() * self.channel_ratio,\ |
| 299 | last_weight_range.max() * self.channel_ratio, last_weight_range) |
| 300 | |
| 301 | kernel_scale = first_weight_range.max() / (first_weight_range + eps) |
| 302 | next_kernel_scale = last_weight_range.max() / (last_weight_range + eps) |
| 303 | first_weight_act_range = op_act_channel_range[pair[0]] |
| 304 | first_weight_act_range = torch.where(first_weight_act_range < 0.01, torch.tensor(0.01,\ |
| 305 | device=first_weight_act_range.device, dtype=torch.float32), first_weight_act_range) |
| 306 | act_scale = first_weight_act_range.max() / (first_weight_act_range + eps) |
| 307 | |
| 308 | if algo_type == 1: |
| 309 | scale = torch.min(kernel_scale, act_scale) |
| 310 | elif algo_type == 2: |
| 311 | kernel_scale = kernel_scale / next_kernel_scale |
| 312 | act_scale = act_scale / next_kernel_scale |
| 313 | scale = torch.min(kernel_scale, act_scale) |
| 314 | scale = torch.min(scale, torch.tensor(ssd_min_scale, dtype=torch.float32, device=scale.device)) |
| 315 | scale /= scale.min() |
| 316 | scale = torch.clamp(scale, 1.0, ssd_max_scale) |
| 317 | else: |
| 318 | kernel_scale = (kernel_scale / next_kernel_scale).sqrt() |
| 319 | scale = (act_scale * kernel_scale).sqrt() |
| 320 | scale = torch.clamp(scale, 1.0, ssd_max_scale) |
| 321 |
no test coverage detected