MCPcopy
hub / github.com/openai/guided-diffusion / load_data_for_worker

Function load_data_for_worker

scripts/super_res_sample.py:77–100  ·  view source on GitHub ↗
(base_samples, batch_size, class_cond)

Source from the content-addressed store, hash-verified

75
76
def load_data_for_worker(base_samples, batch_size, class_cond):
    """
    Yield batches of low-resolution samples for this worker's shard, forever.

    Loads an archive with ``np.load`` from ``base_samples`` (opened through
    ``bf.BlobFile``), taking images from key ``"arr_0"`` and, when
    ``class_cond`` is true, labels from key ``"arr_1"``. Samples are split
    round-robin across distributed ranks: rank ``r`` of ``W`` owns indices
    ``r, r + W, r + 2W, ...``. The shard is cycled without end, so a batch may
    span the end of one pass and the start of the next.

    Because this is a generator, nothing (including the validation below)
    runs until the first ``next()``.

    :param base_samples: path to the sample archive.
    :param batch_size: number of samples per yielded batch; must be positive.
    :param class_cond: if True, also yield the matching labels.
    :yields: dicts with ``"low_res"`` -- a float tensor computed as
        ``x / 127.5 - 1`` (maps [0, 255] to [-1, 1], assuming 8-bit pixel
        values) with axes permuted ``(0, 3, 1, 2)`` (presumably NHWC to NCHW;
        confirm against the sampler that wrote the archive) -- and, if
        ``class_cond``, ``"y"`` holding the stacked labels.
    :raises ValueError: if ``batch_size`` is not positive, or if this rank owns
        no samples; either case would otherwise loop forever without yielding.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    with bf.BlobFile(base_samples, "rb") as f:
        obj = np.load(f)
        image_arr = obj["arr_0"]
        if class_cond:
            label_arr = obj["arr_1"]

    rank = dist.get_rank()
    num_ranks = dist.get_world_size()
    # Indices owned by this rank; computed once and reused on every pass.
    shard = range(rank, len(image_arr), num_ranks)
    if len(shard) == 0:
        # Without this guard the ``while True`` below would spin forever
        # without ever yielding (e.g. fewer samples than ranks).
        raise ValueError(
            f"rank {rank} of {num_ranks} has no samples: "
            f"{base_samples} contains only {len(image_arr)}"
        )

    buffer = []
    label_buffer = []
    while True:
        for i in shard:
            buffer.append(image_arr[i])
            if class_cond:
                label_buffer.append(label_arr[i])
            if len(buffer) == batch_size:
                batch = th.from_numpy(np.stack(buffer)).float()
                batch = batch / 127.5 - 1.0
                batch = batch.permute(0, 3, 1, 2)
                res = dict(low_res=batch)
                if class_cond:
                    res["y"] = th.from_numpy(np.stack(label_buffer))
                yield res
                # Start a fresh batch; leftovers carry over into the next pass.
                buffer, label_buffer = [], []
101
102
103def create_argparser():

Callers 1

main (function) · 0.85

Calls 1

load (method) · 0.80

Tested by

no test coverage detected