| 156 | |
| 157 | |
| 158 | def _split_to_local_id_tensor(indices, local_lower_bound, local_upper_bound): |
| 159 | dtype = dtype_of(indices) |
| 160 | device = indices.device |
| 161 | num_samples = local_upper_bound - local_lower_bound |
| 162 | id_tensor = torch.empty(num_samples, dtype=dtype, device=device) |
| 163 | |
| 164 | if local_upper_bound > len(indices): |
| 165 | remainder = len(indices) - local_lower_bound |
| 166 | id_tensor[0:remainder] = indices[local_lower_bound:] |
| 167 | else: |
| 168 | id_tensor = indices[local_lower_bound:local_upper_bound] |
| 169 | return id_tensor |
| 170 | |
| 171 | |
| 172 | def _divide_by_worker(dataset, batch_size, drop_last): |