MCPcopy
hub / github.com/hustvl/Vim / ram_read_write_worker

Function ram_read_write_worker

det/projects/DensePose/tests/test_tensor_storage.py:122–180  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

120
121
122def ram_read_write_worker():
123 schema = {
124 "tf": SizeData(dtype="float32", shape=(112, 112)),
125 "ti": SizeData(dtype="int32", shape=(4, 64, 64)),
126 }
127 storage = SingleProcessRamTensorStorage(schema, io.BytesIO())
128 world_size = comm.get_world_size()
129 rank = comm.get_rank()
130 data_elts = []
131 # prepare different number of tensors in different processes
132 for i in range(rank + 1):
133 data_elt = {
134 "tf": torch.ones((112, 112), dtype=torch.float32) * (rank + i * world_size),
135 "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (rank + i * world_size),
136 }
137 data_elts.append(data_elt)
138 # write data to the single process storage
139 for i in range(rank + 1):
140 record_id = storage.put(data_elts[i])
141 assert record_id == i, f"Process {rank}: record ID {record_id}, expected {i}"
142 comm.synchronize()
143 # gather all data in process rank 0
144 multi_storage = storage_gather(storage)
145 if rank != 0:
146 return
147 # read and check data from the multiprocess storage
148 for j in range(world_size):
149 for i in range(j):
150 record = multi_storage.get(j, i)
151 record_gt = {
152 "tf": torch.ones((112, 112), dtype=torch.float32) * (j + i * world_size),
153 "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (j + i * world_size),
154 }
155 assert len(record) == len(schema), (
156 f"Process {rank}: multi storage record, rank {j}, id {i}: "
157 f"expected {len(schema)} fields in the record, got {len(record)}"
158 )
159 for field_name in schema:
160 assert field_name in record, (
161 f"Process {rank}: multi storage record, rank {j}, id {i}: "
162 f"field {field_name} not in the record"
163 )
164
165 assert record_gt[field_name].shape == record[field_name].shape, (
166 f"Process {rank}: multi storage record, rank {j}, id {i}: "
167 f"field {field_name}, expected shape {record_gt[field_name].shape} "
168 f"got {record[field_name].shape}"
169 )
170 assert record_gt[field_name].dtype == record[field_name].dtype, (
171 f"Process {rank}: multi storage record, rank {j}, id {i}: "
172 f"field {field_name}, expected dtype {record_gt[field_name].dtype} "
173 f"got {record[field_name].dtype}"
174 )
175 assert torch.allclose(record_gt[field_name], record[field_name]), (
176 f"Process {rank}: multi storage record, rank {j}, id {i}: "
177 f"field {field_name}, tensors are not close enough:"
178 f"L-inf {(record_gt[field_name]-record[field_name]).abs_().max()} "
179 f"L1 {(record_gt[field_name]-record[field_name]).abs_().sum()} "

Callers

nothing calls this directly

Calls 6

SizeData (Class) · 0.90
storage_gather (Function) · 0.90
max (Method) · 0.80
put (Method) · 0.45
get (Method) · 0.45

Tested by

no test coverage detected