(self, r: torch.Tensor, iter: int)
| 56 | return ((self.zeta - self.gamma) * torch.sigmoid(r) + self.gamma).clamp(0, 1) |
| 57 | |
def forward(self, r: torch.Tensor, iter: int) -> torch.Tensor:
    """Compute the rounding loss for ``r`` at iteration ``iter``.

    While ``iter < self.max_iter * self.warm_ratio`` the loss is zero and
    ``self.beta`` is left untouched. From that point on, ``self.beta`` is
    refreshed from ``self.temp_anneal(iter)`` and the loss is::

        self.alpha * sum(1 - (2 * |rectified_sigmoid(r) - 0.5|) ** self.beta)

    Args:
        r: Tensor fed through ``self.rectified_sigmoid``.
        iter: Current iteration index. The name shadows the builtin but is
            kept so that existing keyword callers (``iter=...``) still work.

    Returns:
        The loss as a tensor. During the warm-up phase this is a zero-dim
        zero tensor with the dtype and device of ``r`` (instead of the plain
        int ``0``), so both branches honour the annotated return type and
        callers can use tensor methods unconditionally.
    """
    if iter < self.max_iter * self.warm_ratio:
        # Warm-up: no rounding penalty yet. Return a real tensor so the
        # result type does not depend on the training phase.
        return r.new_zeros(())

    # Side effect kept from the original: the annealed exponent is stored on
    # the instance before it is used.
    self.beta = self.temp_anneal(iter)

    # Distance of each rectified value from 0.5, rescaled by a factor of 2.
    distance = (self.rectified_sigmoid(r) - 0.5).abs() * 2
    return self.alpha * (1 - torch.pow(distance, self.beta)).sum()
| 65 | |
| 66 | |
| 67 | class AdaRoundDelegator(TorchQuantizeDelegator): |