(self, path, rank=0)
| 39 | self.flownet.to(device) |
| 40 | |
def load_model(self, path, rank=0):
    """Load ``self.flownet`` weights from ``<path>/flownet.pkl``.

    Checkpoints saved from a model wrapped in ``torch.nn.DataParallel`` /
    ``DistributedDataParallel`` have every parameter name prefixed with
    ``module.``. That leading prefix is stripped so the weights fit the
    unwrapped ``self.flownet``. Checkpoints saved without a wrapper are
    loaded as-is.

    Args:
        path: Directory containing ``flownet.pkl``.
        rank: Process rank. Only ``rank <= 0`` loads anything; other ranks
            return without touching ``self.flownet``. Presumably ``-1``
            means "not distributed" and other ranks get weights elsewhere
            (e.g. a DDP broadcast) -- confirm against callers.

    Raises:
        RuntimeError: from ``load_state_dict`` (strict mode) if the converted
            keys do not match ``self.flownet``'s parameters and buffers.
    """
    prefix = "module."

    def convert(param):
        # Strip only a *leading* wrapper prefix. The previous version used
        # str.replace, which also stripped "module." inside nested names. It
        # also dropped every key lacking the prefix, so an unwrapped
        # checkpoint became an empty dict.
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in param.items()
        }

    if rank <= 0:
        # Load to CPU so GPU-saved checkpoints also work on CPU-only hosts.
        # load_state_dict copies the values into the existing parameters,
        # so the model stays on whatever device it is already on.
        state = torch.load('{}/flownet.pkl'.format(path), map_location="cpu")
        self.flownet.load_state_dict(convert(state))
| 51 | |
| 52 | def save_model(self, path, rank=0): |
| 53 | if rank == 0: |
no outgoing calls
no test coverage detected