| 124 | |
| 125 | |
def _split_to_local_id_tensor_from_mapping(
    indices, keys, local_lower_bound, local_upper_bound
):
    """Extract a window of a keyed index mapping as a two-column ID tensor.

    The per-key index tensors are treated as if they were concatenated in
    the order given by ``keys`` (keys absent from ``indices`` contribute
    nothing).  The half-open window
    ``[local_lower_bound, local_upper_bound)`` of that virtual concatenation
    is copied into the result.

    Parameters
    ----------
    indices
        Mapping from key to an index tensor; each value must expose
        ``.device`` and ``.shape[0]`` and support slicing.
        NOTE(review): presumably 1-D torch tensors sharing one dtype and
        device -- confirm against the caller.
    keys
        Iterable of keys defining the concatenation order.  The position of
        a key in this iterable is what gets stored in column 0.
    local_lower_bound
        Inclusive start of the window in the virtual concatenation.
    local_upper_bound
        Exclusive end of the window in the virtual concatenation.

    Returns
    -------
    A tensor of shape ``(local_upper_bound - local_lower_bound, 2)`` whose
    column 0 is the position of the source key within ``keys`` and whose
    column 1 is the corresponding value from ``indices[key]``.  The tensor
    is allocated with ``torch.empty``, so any rows not covered by the
    available indices are left uninitialized.
    """
    dtype = dtype_of(indices)
    device = next(iter(indices.values())).device
    num_samples = local_upper_bound - local_lower_bound
    id_tensor = torch.empty(num_samples, 2, dtype=dtype, device=device)

    # ``segment_start`` tracks where the current key's indices begin in the
    # virtual concatenation; ``filled`` counts output rows written so far.
    segment_start = 0
    filled = 0
    for key_pos, key in enumerate(keys):
        if key not in indices:
            # Missing keys are skipped but still consume a position number.
            continue
        values = indices[key]
        segment_stop = segment_start + values.shape[0]

        # Intersect this key's segment with the requested window.
        window_begin = max(local_lower_bound, segment_start)
        window_end = min(local_upper_bound, segment_stop)
        overlap = window_end - window_begin
        if overlap > 0:
            filled_after = filled + overlap
            assert filled_after <= num_samples
            rows = slice(filled, filled_after)
            id_tensor[rows, 0] = key_pos
            # Convert global window positions to offsets local to ``values``.
            id_tensor[rows, 1] = values[
                window_begin - segment_start : window_end - segment_start
            ]
            filled = filled_after
            if filled == num_samples:
                # Output is full; later keys cannot contribute.
                break
        segment_start = segment_stop
    return id_tensor
| 156 | |
| 157 | |
| 158 | def _split_to_local_id_tensor(indices, local_lower_bound, local_upper_bound): |