(self, path, rank=0)
| 35 | self.flownet.to(device) |
| 36 | |
def load_model(self, path, rank=0):
    """Load ``self.flownet`` weights from ``<path>/flownet.pkl``.

    Only the process with ``rank <= 0`` performs the load; for ``rank > 0``
    this is a no-op and ``self.flownet`` is left untouched.

    Args:
        path: Directory containing the ``flownet.pkl`` checkpoint.
        rank: Process rank; loading is skipped when ``rank > 0``.
    """
    prefix = "module."

    def convert(param):
        # Strip only a *leading* "module." prefix (typically added when a model
        # is wrapped in DataParallel / DistributedDataParallel). A leading-only
        # strip keeps names such as "submodule.weight" intact, and unprefixed
        # keys are kept so checkpoints saved from a bare model also load.
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in param.items()
        }

    if rank <= 0:
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints. map_location="cpu" lets GPU-saved checkpoints load on
        # CPU-only hosts; load_state_dict then copies values onto flownet's
        # existing parameters (on whatever device they live).
        state = torch.load("{}/flownet.pkl".format(path), map_location="cpu")
        self.flownet.load_state_dict(convert(state))
| 43 | |
| 44 | def save_model(self, path, rank=0): |
| 45 | if rank == 0: |
no test coverage detected