| 128 | |
| 129 | class Model: |
def __init__(self, local_rank=-1):
    """Build the sub-networks, their shared optimizer/scheduler, and loss helpers.

    Creates ``flownet`` (IFNet), ``contextnet`` (ContextNet) and ``fusionnet``
    (FusionNet). It then trains them jointly with one AdamW optimizer under a
    CyclicLR schedule and instantiates the EPE / Ternary / SOBEL helpers.

    Args:
        local_rank: Device index used for distributed training. With the
            default of -1 the networks are left unwrapped. Any other value
            wraps each network in DistributedDataParallel pinned to that
            single device.
    """
    self.flownet = IFNet()
    self.contextnet = ContextNet()
    self.fusionnet = FusionNet()
    # Defined elsewhere on this class; presumably places the networks on the
    # target device before the optimizer and DDP see them -- TODO confirm.
    self.device()

    # A single optimizer over every parameter of all three networks, gathered
    # in flownet -> contextnet -> fusionnet order.
    networks = (self.flownet, self.contextnet, self.fusionnet)
    self.optimG = AdamW(
        itertools.chain.from_iterable(net.parameters() for net in networks),
        lr=1e-6,
        weight_decay=1e-4,
    )
    # Learning rate cycles between 1e-6 and 1e-3, rising over 8000 steps;
    # momentum cycling is switched off.
    self.schedulerG = optim.lr_scheduler.CyclicLR(
        self.optimG,
        base_lr=1e-6,
        max_lr=1e-3,
        step_size_up=8000,
        cycle_momentum=False,
    )

    # Loss / metric helpers.
    self.epe = EPE()
    self.ter = Ternary()
    self.sobel = SOBEL()

    if local_rank == -1:
        # Non-distributed run: keep the bare modules.
        return
    # Wrap in a fixed order so every rank constructs its DDP modules in the
    # same sequence.
    for attr in ("flownet", "contextnet", "fusionnet"):
        wrapped = DDP(
            getattr(self, attr),
            device_ids=[local_rank],
            output_device=local_rank,
        )
        setattr(self, attr, wrapped)
| 151 | |
| 152 | def train(self): |
| 153 | self.flownet.train() |