| 14 | |
| 15 | |
| 16 | class DatasetWrapper(torch.utils.data.Dataset): |
def __init__(self, dataset, logits_path, topk, write):
    """Wrap ``dataset`` so each sample is paired with saved top-k logit data.

    Args:
        dataset: indexable dataset whose items are wrapped.
        logits_path: location handed to the logits manager.
        topk: number of (index, value) logit pairs stored per record.
        write: truthy for write mode (draw fresh augmentation seeds);
            falsy for read mode (replay saved seeds and logits).
    """
    super().__init__()
    self.dataset, self.logits_path = dataset, logits_path
    # Process-shared integer; presumably so an epoch set in the parent
    # process is visible to DataLoader worker processes — TODO confirm.
    self.epoch = multiprocessing.Value('i', 0)
    self.topk, self.write_mode = topk, write
    # _get_keys is called only after the attributes above are populated.
    self.keys = self._get_keys()
    # Two-slot manager cache, reset by set_epoch; presumably
    # (epoch, manager) — confirm against get_manager.
    self._manager = (None, None)
| 26 | |
def __getitem__(self, index: int):
    """Return sample ``index`` in the shape required by the current mode.

    Write mode yields ``(item, (key, seed))``; read mode yields
    ``(item, (logits_index, logits_value, seed))``.
    """
    loader = self.__getitem_for_write if self.write_mode else self.__getitem_for_read
    return loader(index)
| 31 | |
def __getitem_for_write(self, index: int):
    """Load sample ``index`` under a freshly drawn augmentation seed.

    Returns:
        ``(item, (key, seed))`` where ``seed`` is an ``np.int32``. The read
        path later replays a saved seed through ``AugRandomContext``, so
        this seed presumably gets stored next to the logits — confirm in
        the writer.
    """
    sample_key = self.keys[index]
    # Non-negative seed that fits in int32 (randint's upper bound is exclusive).
    aug_seed = np.int32(np.random.randint(0, 2 ** 31))
    with AugRandomContext(seed=int(aug_seed)):
        sample = self.dataset[index]
    return sample, (sample_key, aug_seed)
| 39 | |
def __getitem_for_read(self, index: int):
    """Load sample ``index`` while replaying the seed saved with its logits.

    Returns:
        ``(item, (logits_index, logits_value, seed))`` with ``seed`` cast
        to ``np.int32``.
    """
    saved_seed, top_idx, top_val = self._get_saved_logits(self.keys[index])
    # Re-enter the augmentation RNG with the stored seed so the sample is
    # augmented the same way it was when its logits were written.
    with AugRandomContext(seed=saved_seed):
        sample = self.dataset[index]
    return sample, (top_idx, top_val, np.int32(saved_seed))
| 46 | |
| 47 | def _get_saved_logits(self, key: str): |
| 48 | manager = self.get_manager() |
| 49 | bstr: bytes = manager.read(key) |
| 50 | # parse the augmentation seed |
| 51 | seed = int(np.frombuffer(bstr[:4], dtype=np.int32)) |
| 52 | # parse the logits index and value |
| 53 | # copy logits_index and logits_value to avoid warning of written flag from PyTorch |
| 54 | bstr = bstr[4:] |
| 55 | logits_index = np.frombuffer( |
| 56 | bstr[:self.topk * 2], dtype=np.int16).copy() |
| 57 | bstr = bstr[self.topk * 2:] |
| 58 | logits_value = np.frombuffer( |
| 59 | bstr[:self.topk * 2], dtype=np.float16).copy() |
| 60 | return seed, logits_index, logits_value |
| 61 | |
def _build_manager(self, logits_path: str):
    """Create a ``TxtManager`` for ``logits_path`` using the rank from ``get_rank()``.

    Each fixed-size record holds a 4-byte seed plus ``topk`` (index, value)
    pairs at 2 bytes per field, i.e. ``4 + topk * 4`` bytes.
    """
    return TxtManager(logits_path, 4 + self.topk * 4, get_rank())
| 67 | |
def set_epoch(self, epoch: int):
    """Store ``epoch`` in the shared counter and clear the cached manager.

    Resetting ``_manager`` to ``(None, None)`` presumably makes
    ``get_manager`` build a fresh manager for the new epoch — confirm.
    """
    self.epoch.value, self._manager = epoch, (None, None)
| 71 | |
| 72 | def get_manager(self): |
| 73 | epoch = self.epoch.value |