(self, mode=True)
| 664 | param.requires_grad = False |
| 665 | |
def train(self, mode=True):
    """Set the training/eval mode while keeping frozen parts frozen.

    Calls the parent class's ``train(mode)`` and then re-applies
    ``self._freeze_stages()``. When switching into training mode with
    ``self.norm_eval`` enabled, every ``_BatchNorm`` submodule is put back
    into eval mode. With standard PyTorch BatchNorm, this means those layers
    normalize with their stored running statistics and do not update them.

    Args:
        mode (bool): ``True`` for training mode, ``False`` for evaluation
            mode. Defaults to ``True``.

    Returns:
        The module itself. This matches the ``torch.nn.Module.train``
        contract, so ``model = model.train()`` and chained calls keep
        working.
    """
    super(Vig, self).train(mode)
    # The base-class call flips the mode of every submodule, so freezing has
    # to be re-applied afterwards.
    # NOTE(review): _freeze_stages is defined outside this view; presumably it
    # disables gradients / eval-modes the frozen stages — confirm.
    self._freeze_stages()
    if mode and self.norm_eval:
        for m in self.modules():
            # eval() is applied to BatchNorm layers only; all other
            # submodules stay in training mode.
            if isinstance(m, _BatchNorm):
                m.eval()
    # nn.Module.train returns self; preserve that for callers.
    return self
| 674 | |
| 675 | |
| 676 | @MODELS.register_module() |