(
self, config: TensorQuantizationConfig, var: Variable,
is_parameter_trainable: bool = True,
is_scale_trainable: bool = True,
is_offset_trainable: bool = True
)
| 317 | |
| 318 | class LSQDelegator(TorchQuantizeDelegator): |
| 319 | def __init__( # decide which of parameter / scale / offset are trainable and snapshot them |
| 320 | self, config: TensorQuantizationConfig, var: Variable, |
| 321 | is_parameter_trainable: bool = True, |
| 322 | is_scale_trainable: bool = True, |
| 323 | is_offset_trainable: bool = True |
| 324 | ) -> None: |
| 325 | self.config = config |
| 326 | self.is_parameter = var.is_parameter |
| 327 | self.var = var |
| 328 | self.policy = config.policy |
| 329 | self.passive = config.state == QuantizationStates.PASSIVE |
| 330 | |
| 331 | self.param_backup = None |
| 332 | if self.is_parameter and is_parameter_trainable: |
| 333 | self.param_backup = self.var.value.clone() |
| 334 | |
| 335 | # There are 4 checks for scale training (all must pass): |
| 336 | # 1. config.scale is a torch.Tensor |
| 337 | # 2. state is ACTIVATED and this config is not dominated by another config |
| 338 | # 3. policy has LINEAR but does not have POWER_OF_2 |
| 339 | # 4. is_scale_trainable = True |
| 340 | self.scale_backup = None |
| 341 | self.is_scale_trainable = False |
| 342 | if is_scale_trainable: |
| 343 | policy_check = not config.policy.has_property(QuantizationProperty.POWER_OF_2) |
| 344 | linear_check = config.policy.has_property(QuantizationProperty.LINEAR) |
| 345 | state_check = ((config.state == QuantizationStates.ACTIVATED) and (config.dominated_by == config)) |
| 346 | value_check = isinstance(config.scale, torch.Tensor) |
| 347 | if policy_check and state_check and value_check and linear_check: |
| 348 | self.is_scale_trainable = True |
| 349 | self.scale_backup = self.config.scale.detach().clone() |
| 350 | |
| 351 | # There are 4 checks for offset training (all must pass): |
| 352 | # 1. config.offset is a torch.Tensor |
| 353 | # 2. state is ACTIVATED and this config is not dominated by another config |
| 354 | # 3. policy does not have SYMMETRICAL |
| 355 | # 4. is_offset_trainable = True |
| 356 | self.offset_backup = None |
| 357 | self.is_offset_trainable = False |
| 358 | if is_offset_trainable: |
| 359 | policy_check = not config.policy.has_property(QuantizationProperty.SYMMETRICAL) |
| 360 | state_check = ((config.state == QuantizationStates.ACTIVATED) and (config.dominated_by == config)) |
| 361 | value_check = isinstance(config.offset, torch.Tensor) |
| 362 | if policy_check and state_check and value_check: |
| 363 | self.is_offset_trainable = True |
| 364 | self.offset_backup = self.config.offset.detach().clone() |
| 365 | |
| 366 | def trainable_tensors(self) -> List[torch.Tensor]: |
| 367 | params = [] |
nothing calls this directly
no test coverage detected