| 575 | ) |
| 576 | |
def load(self):
    """Restore the cached split and dataset metadata from ``self.save_path``.

    Reads ``info_<hash>.json`` (split ratio, ``num_tasks``, ``num_classes``)
    and ``split_<hash>.npz`` (``train_idx`` / ``val_idx`` / ``test_idx``
    arrays) for this object's ``self.hash``, and stores them on ``self``.
    The index arrays are converted with ``F.zerocopy_from_numpy``.

    Raises:
        ValueError: if the split ratio stored in the cache differs from the
            ``split_ratio`` this object was configured with.
    """

    def _as_list(ratio):
        # JSON always decodes sequences as lists, so normalize both sides
        # before comparing; otherwise a tuple (or array) ratio never equals
        # the cached list and a valid cache is wrongly rejected.
        if ratio is None:
            return None
        try:
            return list(ratio)
        except TypeError:
            return ratio

    info_path = os.path.join(self.save_path, "info_{}.json".format(self.hash))
    with open(info_path, "r") as f:
        info = json.load(f)

    if _as_list(info["split_ratio"]) != _as_list(self.split_ratio):
        raise ValueError(
            "Provided split ratio is different from the cached file. "
            "Re-process the dataset."
        )
    self.split_ratio = info["split_ratio"]
    self.num_tasks = info["num_tasks"]
    self.num_classes = info["num_classes"]

    split_path = os.path.join(self.save_path, "split_{}.npz".format(self.hash))
    # The object returned by np.load for an .npz archive keeps the file open
    # until closed; indexing it reads each member array into memory, so the
    # converted indices remain valid after the archive is closed.
    with np.load(split_path) as split:
        self.train_idx = F.zerocopy_from_numpy(split["train_idx"])
        self.val_idx = F.zerocopy_from_numpy(split["val_idx"])
        self.test_idx = F.zerocopy_from_numpy(split["test_idx"])
| 597 | |
| 598 | def save(self): |
| 599 | if not os.path.exists(self.save_path): |