https://github.com/galatolofederico/pytorch-balanced-batch/blob/master/sampler.py
| 9 | |
| 10 | # https://github.com/khornlund/pytorch-balanced-sampler |
class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    '''
    Sampler that yields dataset indices class by class, in round-robin order.

    Indices are grouped by label. Every class with fewer samples than the
    largest one is padded with randomly repeated indices of its own (drawn
    once, at construction time), so each class contributes exactly
    ``balanced_max`` indices per pass and ``len(self)`` equals
    ``balanced_max * number_of_classes``.

    Adapted from
    https://github.com/galatolofederico/pytorch-balanced-batch/blob/master/sampler.py
    '''

    def __init__(self, dataset, labels=None):
        '''
        Group the indices of ``dataset`` by label and oversample minorities.

        Args:
            dataset: indexable collection supporting ``len()``. When
                ``labels`` is None, the label of item ``idx`` is read from
                ``dataset[idx][0][1]['category']``.
            labels: optional per-index labels (sequence or tensor). When
                given, ``labels[idx]`` is used as the label of item ``idx``
                and the dataset items themselves are not inspected.
        '''
        self.labels = labels
        self.dataset = collections.defaultdict(list)

        # Save all the indices for all the classes
        for idx in range(len(dataset)):
            self.dataset[self._get_label(dataset, idx)].append(idx)

        # Size of the largest class; 0 when the dataset is empty.
        self.balanced_max = max(map(len, self.dataset.values()), default=0)

        # Oversample the classes with fewer elements than the max
        for bucket in self.dataset.values():
            while len(bucket) < self.balanced_max:
                bucket.append(random.choice(bucket))

        self.keys = list(self.dataset.keys())
        logging.warning(self.keys)

        # No longer used by __iter__ (iteration is stateless now); kept so
        # external code that reads these attributes keeps working.
        self.currentkey = 0
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        '''
        Yield one index per class in turn until every class is exhausted.

        The order is: position 0 of every class (in ``self.keys`` order),
        then position 1 of every class, and so on. No iteration state is
        stored on the instance, so an abandoned pass or several concurrent
        iterators cannot affect each other, and an empty sampler simply
        yields nothing.
        '''
        for position in range(self.balanced_max):
            for key in self.keys:
                yield self.dataset[key][position]

    def _get_label(self, dataset, idx):
        '''
        Return the (hashable) label of item ``idx``.

        Uses ``self.labels[idx]`` when labels were passed to the constructor;
        values exposing ``.item()`` (e.g. 0-dim tensors) are unwrapped so
        they can serve as dictionary keys. Otherwise the label is read from
        ``dataset[idx][0][1]['category']``.
        '''
        if self.labels is not None:
            label = self.labels[idx]
            return label.item() if hasattr(label, 'item') else label
        return dataset[idx][0][1]['category']

    def __len__(self):
        '''Return the number of indices produced by one full pass.'''
        return self.balanced_max * len(self.keys)
| 55 |
no outgoing calls
no test coverage detected