Method
__init__
(self, local_rank=-1, arbitrary=False)
Source from the content-addressed store, hash-verified
class Model:
    def __init__(self, local_rank=-1, arbitrary=False):
        if arbitrary == True:
            self.flownet = IFNet_m()
        else:
            self.flownet = IFNet()
        self.device()
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-3)  # use large weight decay may avoid NaN loss
        self.epe = EPE()
        self.lap = LapLoss()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        self.flownet.train()
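The constructor makes two decisions: `arbitrary` selects which network class is built, and any `local_rank` other than -1 wraps that network for distributed training. The sketch below mirrors only that control flow so it can run without PyTorch. `FakeIFNet`, `FakeIFNetM`, `FakeDDP`, and `build_flownet` are hypothetical stand-ins introduced for illustration; they are not part of this repository and do not reproduce the real `IFNet`, `IFNet_m`, or `DDP` behavior.

```python
# Hypothetical stand-ins: placeholders for IFNet, IFNet_m and DDP,
# used only to illustrate the branching in Model.__init__.
class FakeIFNet:
    name = "IFNet"


class FakeIFNetM:
    name = "IFNet_m"


class FakeDDP:
    """Mimics only the wrapping aspect of a distributed wrapper."""

    def __init__(self, module, device_ids, output_device):
        self.module = module              # the wrapped network
        self.device_ids = device_ids
        self.output_device = output_device


def build_flownet(local_rank=-1, arbitrary=False):
    """Mirror the selection and wrapping steps of Model.__init__."""
    # Step 1: pick the network variant.
    flownet = FakeIFNetM() if arbitrary == True else FakeIFNet()
    # Step 2: wrap for distributed training unless local_rank is -1.
    if local_rank != -1:
        flownet = FakeDDP(flownet, device_ids=[local_rank], output_device=local_rank)
    return flownet


single = build_flownet()                               # default: plain network
distributed = build_flownet(local_rank=0, arbitrary=True)
print(type(single).__name__)
print(type(distributed).__name__, distributed.module.name)
```

One consequence worth keeping in mind: once the network is wrapped, `self.flownet` refers to the wrapper, so the underlying network is reached through its `.module` attribute, as with PyTorch's `DistributedDataParallel`.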
Callers
nothing calls this directly
Tested by
no test coverage detected