MCPcopy Index your code
hub / github.com/openai/improved-diffusion / load_data_for_worker

Function load_data_for_worker

scripts/super_res_sample.py:75–98  ·  view source on GitHub ↗
(base_samples, batch_size, class_cond)

Source from the content-addressed store, hash-verified

73
74
def load_data_for_worker(base_samples, batch_size, class_cond):
    """
    Yield batches of low-resolution base samples for this distributed worker.

    The ``.npz`` archive at ``base_samples`` is read through ``bf.BlobFile``.
    This process keeps every ``world_size``-th entry starting at its own rank
    and cycles over that shard forever. A partially filled batch carries over
    into the next pass, so no sample is dropped at a pass boundary.

    :param base_samples: path to an ``.npz`` file. ``arr_0`` holds the images.
                         When ``class_cond`` is set, ``arr_1`` holds the labels.
    :param batch_size: number of samples per yielded batch; must be >= 1.
    :param class_cond: if True, also yield the matching labels under ``"y"``.
    :yields: a dict with ``"low_res"``, a float tensor permuted from
             (N, H, W, C) to (N, C, H, W) and mapped via ``x / 127.5 - 1``.
             This presumably assumes pixel values in [0, 255]; confirm against
             the producer of ``base_samples``. When ``class_cond`` is set, the
             dict also has ``"y"``, a tensor of the stacked labels.
    :raises ValueError: if ``batch_size < 1`` or this rank's shard is empty.
                        Either case would otherwise loop forever without yielding.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    with bf.BlobFile(base_samples, "rb") as f:
        obj = np.load(f)
        image_arr = obj["arr_0"]
        if class_cond:
            label_arr = obj["arr_1"]

    rank = dist.get_rank()
    num_ranks = dist.get_world_size()

    # Strided sharding: rank r owns indices r, r + num_ranks, r + 2 * num_ranks, ...
    # The range is loop-invariant, so build it once.
    shard_indices = range(rank, len(image_arr), num_ranks)
    if len(shard_indices) == 0:
        # With no samples, the endless loop below would never fill a batch.
        raise ValueError(
            f"rank {rank} of {num_ranks} has no samples to process: "
            f"{base_samples} contains only {len(image_arr)} images"
        )

    buffer = []
    label_buffer = []
    while True:
        for i in shard_indices:
            buffer.append(image_arr[i])
            if class_cond:
                label_buffer.append(label_arr[i])
            if len(buffer) == batch_size:
                batch = th.from_numpy(np.stack(buffer)).float()
                batch = batch / 127.5 - 1.0
                # Channels-last -> channels-first for the model.
                batch = batch.permute(0, 3, 1, 2)
                res = dict(low_res=batch)
                if class_cond:
                    res["y"] = th.from_numpy(np.stack(label_buffer))
                yield res
                buffer, label_buffer = [], []
99
100
101def create_argparser():

Callers 1

mainFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected