| 113 | |
| 114 | class Model: |
def __init__(self, local_rank=-1):
    """Assemble the sub-networks, their shared optimizer/scheduler and loss helpers.

    Args:
        local_rank: Device index handed to DistributedDataParallel. The
            default of -1 skips DDP wrapping entirely.
    """
    # Instantiate the three sub-networks, then call self.device()
    # (defined elsewhere; presumably places them on the compute device).
    self.flownet = IFNet()
    self.contextnet = ContextNet()
    self.fusionnet = FusionNet()
    self.device()

    # A single AdamW instance optimizes the parameters of all three networks.
    joint_params = itertools.chain(
        self.flownet.parameters(),
        self.contextnet.parameters(),
        self.fusionnet.parameters(),
    )
    self.optimG = AdamW(joint_params, lr=1e-6, weight_decay=1e-4)

    # Cyclic learning rate oscillating between 1e-6 and 1e-3, ramping up over
    # 8000 scheduler steps; cycle_momentum=False leaves AdamW's betas untouched.
    self.schedulerG = optim.lr_scheduler.CyclicLR(
        self.optimG,
        base_lr=1e-6,
        max_lr=1e-3,
        step_size_up=8000,
        cycle_momentum=False,
    )

    # Loss / metric helper modules (defined elsewhere in the project).
    self.epe = EPE()
    self.ter = Ternary()
    self.sobel = SOBEL()

    if local_rank == -1:
        return

    # Distributed mode: wrap every sub-network in DDP bound to this rank's device.
    # This happens after the optimizer is built; DDP keeps the wrapped module
    # (and its parameter tensors) as `.module`, so optimG's references stay valid.
    def wrap(net):
        return DDP(net, device_ids=[local_rank], output_device=local_rank)

    self.flownet = wrap(self.flownet)
    self.contextnet = wrap(self.contextnet)
    self.fusionnet = wrap(self.fusionnet)
| 136 | |
| 137 | def train(self): |
| 138 | self.flownet.train() |