(self)
| 79 | return batch |
| 80 | |
def __next__(self):
    """Return the next batch of indices.

    When ``self.mapping_keys`` is ``None`` the tensor produced by
    ``self._next_indices()`` is returned as a copy.  Otherwise the batch
    is read as rows of ``(type ID, index)`` pairs and converted into a
    dictionary that maps ``self.mapping_keys[type ID]`` to a tensor
    holding the indices of that type.  Within each type the indices keep
    the order they had in the batch, and dictionary entries are inserted
    in ascending type-ID order.
    """
    pairs = self._next_indices()
    if self.mapping_keys is None:
        # The clone() is a workaround for issue #3755; the root cause is
        # not understood yet and still needs to be investigated.
        return pairs.clone()

    # Group the rows by type ID.  The sort is stable, so rows that share
    # a type keep their relative order from the batch.
    order = torch.sort(pairs[:, 0], stable=True)[1]
    sorted_type_ids = pairs[:, 0][order]
    sorted_indices = pairs[:, 1][order]

    # After sorting, every type occupies one contiguous run; the running
    # sum of the run lengths gives the exclusive end of each run.
    present_types, run_lengths = torch.unique_consecutive(
        sorted_type_ids, return_counts=True
    )
    run_ends = run_lengths.cumsum(0).tolist()
    run_starts = [0] + run_ends[:-1]

    grouped = {}
    for type_id, start, stop in zip(
        present_types.tolist(), run_starts, run_ends
    ):
        # clone() so each entry owns its storage rather than viewing the
        # sorted buffer.
        grouped[self.mapping_keys[type_id]] = sorted_indices[
            start:stop
        ].clone()
    return grouped
| 106 | |
| 107 | |
| 108 | def _get_id_tensor_from_mapping(indices, device, keys): |
nothing calls this directly
no test coverage detected