(self, r: torch.Tensor)
| 53 | super().__init__() |
| 54 | |
def rectified_sigmoid(self, r: torch.Tensor) -> torch.Tensor:
    """Squash ``r`` through a stretched sigmoid, then clip to ``[0, 1]``.

    ``torch.sigmoid(r)`` lies in (0, 1). It is mapped linearly onto the
    interval (``self.gamma``, ``self.zeta``), and the result is clamped to
    ``[0, 1]``.

    NOTE(review): if ``gamma < 0`` and ``zeta > 1``, the clamp yields exact
    0/1 for sufficiently negative/positive ``r``. That looks like the
    AdaRound-style rectified sigmoid. The actual values of ``gamma`` and
    ``zeta`` are set outside this view, so confirm against ``__init__``.

    Args:
        r: Input logits. Any shape is accepted; the op is element-wise.

    Returns:
        A tensor with the same shape as ``r`` and values in ``[0, 1]``.
    """
    low, high = self.gamma, self.zeta
    # Rescale the (0, 1) sigmoid output onto (low, high).
    stretched = torch.sigmoid(r) * (high - low) + low
    # Clip back into the unit interval.
    return torch.clamp(stretched, min=0, max=1)
| 57 | |
| 58 | def forward(self, r: torch.Tensor, iter: int) -> torch.Tensor: |
| 59 | if iter < self.max_iter * self.warm_ratio: |