| 114 | self.init_cache() |
| 115 | |
def init_cache(self):
    """Pre-load sample data into ``self.samples`` according to ``self.cache_mode``.

    ``self.samples`` is read as a sequence of ``(path, target)`` pairs and is
    replaced by a list of the same length. Each entry is either
    ``(ZipReader.read(path), target)`` (cached) or the untouched
    ``(path, target)`` (not cached):

    * ``"full"``: every sample is read and cached on every rank.
    * ``"part"``: sample ``index`` is read only on the rank where
      ``index % world_size == global_rank``; all other entries keep the path.

    Progress is printed roughly every tenth of the samples (at least every
    sample when there are fewer than ten).

    Raises:
        AssertionError: if ``self.cache_mode`` is not ``"part"`` or ``"full"``.
    """
    assert self.cache_mode in ["part", "full"]
    n_sample = len(self.samples)
    # NOTE(review): `dist` is presumably torch.distributed; these calls expect
    # an initialized process group — confirm against the module imports.
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Clamp to 1: with fewer than 10 samples `n_sample // 10` is 0, and the
    # former `index % 0` raised ZeroDivisionError.
    log_every = max(n_sample // 10, 1)

    cached_samples = []
    # perf_counter is monotonic, so the interval is unaffected by clock changes.
    block_start = time.perf_counter()
    for index, (path, target) in enumerate(self.samples):
        if index % log_every == 0:
            t = time.perf_counter() - block_start
            print(f'global_rank {global_rank} cached {index}/{n_sample} takes {t:.2f}s per block')
            block_start = time.perf_counter()
        # "full" caches every sample; "part" shards the cache round-robin by rank.
        if self.cache_mode == "full" or index % world_size == global_rank:
            # Presumably the raw file contents — ZipReader is defined elsewhere.
            cached_samples.append((ZipReader.read(path), target))
        else:
            cached_samples.append((path, target))
    self.samples = cached_samples
| 137 | |
| 138 | def __getitem__(self, index): |
| 139 | """ |