MCPcopy Index your code
hub / github.com/jindongwang/transferlearning / BalancedBatchSampler

Class BalancedBatchSampler

code/ASR/Adapter/balanced_sampler.py:11–54  ·  view source on GitHub ↗

https://github.com/galatolofederico/pytorch-balanced-batch/blob/master/sampler.py

Source from the content-addressed store, hash-verified

9
# https://github.com/khornlund/pytorch-balanced-sampler
class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    """Sampler that yields dataset indices with every class equally represented.

    Adapted from
    https://github.com/galatolofederico/pytorch-balanced-batch/blob/master/sampler.py

    At construction time every dataset index is grouped by its label (see
    ``_get_label``). Classes with fewer indices than the largest class are
    oversampled with replacement, drawing only from that class's original
    indices, until every class holds ``balanced_max`` indices. Iteration then
    visits the classes round-robin (first index of each class, then the
    second of each class, ...), for ``balanced_max * len(self.keys)`` indices
    per epoch.

    Args:
        dataset: sized, indexable dataset; each item must support
            ``item[0][1]['category']``, which is used as the balancing label.
        labels: stored on ``self.labels``; not currently used for label lookup.
    """

    def __init__(self, dataset, labels=None):
        self.labels = labels
        # label -> list of dataset indices carrying that label (insertion order
        # of first appearance determines the round-robin class order).
        self.dataset = collections.defaultdict(list)
        for idx in range(len(dataset)):
            self.dataset[self._get_label(dataset, idx)].append(idx)
        # Size of the largest class; 0 for an empty dataset.
        self.balanced_max = max(
            (len(indices) for indices in self.dataset.values()), default=0
        )

        # Oversample the classes with fewer elements than the max. Sample only
        # from each class's *original* indices: sampling from the growing list
        # (a Polya urn) lets early duplicates snowball and skews the mix.
        for indices in self.dataset.values():
            deficit = self.balanced_max - len(indices)
            if deficit > 0:
                indices.extend(random.choices(indices, k=deficit))

        self.keys = list(self.dataset.keys())
        logging.warning(self.keys)
        # Retained for backward compatibility only; iteration keeps its own
        # local cursor and no longer reads or mutates these.
        self.currentkey = 0
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        """Yield indices round-robin across classes.

        The cursor lives in this generator, so an epoch that is abandoned
        early, or a second concurrent iterator, cannot corrupt later epochs.
        """
        for position in range(self.balanced_max):
            for label in self.keys:
                yield self.dataset[label][position]

    def _get_label(self, dataset, idx):
        """Return the balancing label of ``dataset[idx]``.

        Reads ``dataset[idx][0][1]['category']``.
        """
        return dataset[idx][0][1]['category']

    def __len__(self):
        """Return the number of indices yielded per epoch."""
        return self.balanced_max * len(self.keys)
55

Callers 1

load_multilingual_data · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected