| 75 | |
| 76 | |
def load_data_for_worker(base_samples, batch_size, class_cond):
    """Yield batches of low-resolution samples for this distributed rank, forever.

    Reads an ``.npz`` archive from ``base_samples`` via ``bf.BlobFile``.
    Images come from ``arr_0`` and, when ``class_cond`` is true, labels come
    from ``arr_1``. Rank ``r`` of ``W`` ranks takes indices ``r, r + W, ...``
    and cycles over that shard indefinitely. Each batch is emitted once
    ``batch_size`` samples have accumulated; a partial batch carries over into
    the next pass over the shard.

    :param base_samples: path to the ``.npz`` file.
    :param batch_size: number of samples per yielded batch; must be positive.
    :param class_cond: if True, also yield labels under the ``"y"`` key.
    :yields: dict with ``"low_res"``, a float tensor computed as
        ``x / 127.5 - 1.0`` with axes permuted ``(0, 3, 1, 2)``, plus ``"y"``
        (a tensor of stacked labels) when ``class_cond`` is true.
        NOTE(review): this assumes the images are stored NHWC with pixel
        values in [0, 255], so the output is NCHW in [-1, 1]. Confirm this
        against whatever produced the ``.npz`` file.
    :raises ValueError: on the first ``next()`` if ``batch_size`` is not
        positive, or if this rank's shard is empty. Either condition would
        otherwise make the generator spin forever without yielding.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Materialize the arrays while the file is open. Afterwards, close both
    # the NpzFile and the underlying blob handle.
    with bf.BlobFile(base_samples, "rb") as f, np.load(f) as obj:
        image_arr = obj["arr_0"]
        label_arr = obj["arr_1"] if class_cond else None

    rank = dist.get_rank()
    num_ranks = dist.get_world_size()
    # Build the shard's index range once; it is identical on every pass.
    indices = range(rank, len(image_arr), num_ranks)
    if len(indices) == 0:
        # Without this check, the ``while True`` loop below would iterate an
        # empty range forever and never yield.
        raise ValueError(
            f"rank {rank} of {num_ranks} has no samples: "
            f"{base_samples} contains only {len(image_arr)} images"
        )

    buffer = []
    label_buffer = []
    while True:
        for i in indices:
            buffer.append(image_arr[i])
            if class_cond:
                label_buffer.append(label_arr[i])
            if len(buffer) == batch_size:
                batch = th.from_numpy(np.stack(buffer)).float()
                batch = batch / 127.5 - 1.0
                batch = batch.permute(0, 3, 1, 2)
                res = dict(low_res=batch)
                if class_cond:
                    res["y"] = th.from_numpy(np.stack(label_buffer))
                yield res
                buffer, label_buffer = [], []
| 101 | |
| 102 | |
| 103 | def create_argparser(): |