| 15 | |
| 16 | |
class LambdaLR:
    """Linear learning-rate decay multiplier.

    The value returned by :meth:`step` is 1.0 while
    ``epoch + offset <= decay_start_epoch``. It then falls linearly to 0.0
    at ``epoch + offset == n_epochs`` and stays at 0.0 after that.

    NOTE(review): presumably ``step`` is handed to a scheduler such as
    ``torch.optim.lr_scheduler.LambdaLR`` as its ``lr_lambda``; confirm
    against the caller.

    Args:
        n_epochs: Shifted epoch (``epoch + offset``) at which the multiplier
            reaches 0.0.
        offset: Constant added to the ``epoch`` passed to :meth:`step`
            before the schedule is evaluated.
        decay_start_epoch: Shifted epoch at which the linear decay begins.

    Raises:
        ValueError: If ``decay_start_epoch`` is not strictly less than
            ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        # Explicit check instead of ``assert`` so the validation is not
        # stripped under ``python -O``. It also guarantees the divisor in
        # ``step`` is positive.
        if n_epochs - decay_start_epoch <= 0:
            raise ValueError("Decay must start before the training session ends!")
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the learning-rate multiplier for ``epoch``, in [0.0, 1.0]."""
        # Epochs elapsed since decay started (0 before the decay window).
        elapsed = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_span = self.n_epochs - self.decay_start_epoch
        # Clamp at 0.0: past ``n_epochs`` the raw linear formula turns
        # negative, which would flip the sign of the learning rate.
        return max(0.0, 1.0 - elapsed / decay_span)
| 26 | |
| 27 | |
| 28 | ############################## |