MCPcopy
hub / github.com/microsoft/Cream / DatasetWrapper

Class DatasetWrapper

TinyViT/data/augmentation/dataset_wrapper.py:16–90  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

14
15
16class DatasetWrapper(torch.utils.data.Dataset):
def __init__(self, dataset, logits_path, topk, write):
    """Wrap ``dataset`` so each item is paired with saved-logits metadata.

    Args:
        dataset: The underlying dataset; indexed as ``dataset[index]``.
        logits_path: Location handed to the logits manager.
        topk: Number of (index, value) logit pairs stored per item.
        write: Truthy to serve items in write mode, falsy for read mode.
    """
    super().__init__()
    self.dataset, self.logits_path = dataset, logits_path
    # Integer held in a multiprocessing.Value and updated via set_epoch().
    self.epoch = multiprocessing.Value('i', 0)
    self.topk, self.write_mode = topk, write
    # Keys are resolved after the attributes above are in place.
    self.keys = self._get_keys()
    # Two-slot manager state, initially empty; set_epoch() resets it.
    self._manager = (None, None)
26
27 def __getitem__(self, index: int):
28 if self.write_mode:
29 return self.__getitem_for_write(index)
30 return self.__getitem_for_read(index)
31
def __getitem_for_write(self, index: int):
    """Fetch an item under a freshly drawn augmentation seed.

    Returns:
        ``(item, (key, seed))`` where ``seed`` is an ``np.int32`` drawn
        from ``[0, 2**31)`` and ``key`` is ``self.keys[index]``.
    """
    key = self.keys[index]
    # Draw the augmentation seed for this sample.
    drawn = np.random.randint(0, 1 << 31)
    seed = np.int32(drawn)
    # The context receives the seed as a plain Python int.
    with AugRandomContext(seed=int(seed)):
        sample = self.dataset[index]
    return sample, (key, seed)
39
def __getitem_for_read(self, index: int):
    """Fetch an item under the seed stored with its saved logits.

    Returns:
        ``(item, (logits_index, logits_value, seed))`` with ``seed``
        converted to ``np.int32``.
    """
    saved = self._get_saved_logits(self.keys[index])
    seed, logits_index, logits_value = saved
    # Load the item inside a context built from the stored seed.
    with AugRandomContext(seed=seed):
        sample = self.dataset[index]
    return sample, (logits_index, logits_value, np.int32(seed))
46
47 def _get_saved_logits(self, key: str):
48 manager = self.get_manager()
49 bstr: bytes = manager.read(key)
50 # parse the augmentation seed
51 seed = int(np.frombuffer(bstr[:4], dtype=np.int32))
52 # parse the logits index and value
53 # copy logits_index and logits_value to avoid warning of written flag from PyTorch
54 bstr = bstr[4:]
55 logits_index = np.frombuffer(
56 bstr[:self.topk * 2], dtype=np.int16).copy()
57 bstr = bstr[self.topk * 2:]
58 logits_value = np.frombuffer(
59 bstr[:self.topk * 2], dtype=np.float16).copy()
60 return seed, logits_index, logits_value
61
def _build_manager(self, logits_path: str):
    """Create a ``TxtManager`` for ``logits_path`` sized for one record.

    A record is ``topk`` index/value pairs at 2 bytes per field plus a
    4-byte seed.
    """
    pair_bytes = 2 * 2  # [idx, value] * 2 bytes
    seed_bytes = 4
    record_size = self.topk * pair_bytes + seed_bytes
    current_rank = get_rank()
    return TxtManager(logits_path, record_size, current_rank)
67
68 def set_epoch(self, epoch: int):
69 self.epoch.value = epoch
70 self._manager = (None, None)
71
72 def get_manager(self):
73 epoch = self.epoch.value

Callers 1

build_loader · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected