()
| 120 | |
| 121 | |
| 122 | def ram_read_write_worker(): |
| 123 | schema = { |
| 124 | "tf": SizeData(dtype="float32", shape=(112, 112)), |
| 125 | "ti": SizeData(dtype="int32", shape=(4, 64, 64)), |
| 126 | } |
| 127 | storage = SingleProcessRamTensorStorage(schema, io.BytesIO()) |
| 128 | world_size = comm.get_world_size() |
| 129 | rank = comm.get_rank() |
| 130 | data_elts = [] |
| 131 | # prepare different number of tensors in different processes |
| 132 | for i in range(rank + 1): |
| 133 | data_elt = { |
| 134 | "tf": torch.ones((112, 112), dtype=torch.float32) * (rank + i * world_size), |
| 135 | "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (rank + i * world_size), |
| 136 | } |
| 137 | data_elts.append(data_elt) |
| 138 | # write data to the single process storage |
| 139 | for i in range(rank + 1): |
| 140 | record_id = storage.put(data_elts[i]) |
| 141 | assert record_id == i, f"Process {rank}: record ID {record_id}, expected {i}" |
| 142 | comm.synchronize() |
| 143 | # gather all data in process rank 0 |
| 144 | multi_storage = storage_gather(storage) |
| 145 | if rank != 0: |
| 146 | return |
| 147 | # read and check data from the multiprocess storage |
| 148 | for j in range(world_size): |
| 149 | for i in range(j): |
| 150 | record = multi_storage.get(j, i) |
| 151 | record_gt = { |
| 152 | "tf": torch.ones((112, 112), dtype=torch.float32) * (j + i * world_size), |
| 153 | "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (j + i * world_size), |
| 154 | } |
| 155 | assert len(record) == len(schema), ( |
| 156 | f"Process {rank}: multi storage record, rank {j}, id {i}: " |
| 157 | f"expected {len(schema)} fields in the record, got {len(record)}" |
| 158 | ) |
| 159 | for field_name in schema: |
| 160 | assert field_name in record, ( |
| 161 | f"Process {rank}: multi storage record, rank {j}, id {i}: " |
| 162 | f"field {field_name} not in the record" |
| 163 | ) |
| 164 | |
| 165 | assert record_gt[field_name].shape == record[field_name].shape, ( |
| 166 | f"Process {rank}: multi storage record, rank {j}, id {i}: " |
| 167 | f"field {field_name}, expected shape {record_gt[field_name].shape} " |
| 168 | f"got {record[field_name].shape}" |
| 169 | ) |
| 170 | assert record_gt[field_name].dtype == record[field_name].dtype, ( |
| 171 | f"Process {rank}: multi storage record, rank {j}, id {i}: " |
| 172 | f"field {field_name}, expected dtype {record_gt[field_name].dtype} " |
| 173 | f"got {record[field_name].dtype}" |
| 174 | ) |
| 175 | assert torch.allclose(record_gt[field_name], record[field_name]), ( |
| 176 | f"Process {rank}: multi storage record, rank {j}, id {i}: " |
| 177 | f"field {field_name}, tensors are not close enough:" |
| 178 | f"L-inf {(record_gt[field_name]-record[field_name]).abs_().max()} " |
| 179 | f"L1 {(record_gt[field_name]-record[field_name]).abs_().sum()} " |
nothing calls this directly
no test coverage detected