Similar to the normal implementation of a distributed sampler, except that it is implemented at the batch-sampler level instead of at the sampler level. This allows arbitrary data samplers (sequential, random, WeightedRandomSampler, etc.) to be wrapped with this batch sampler.
| 73 | self.epoch = epoch |
| 74 | |
| 75 | class DistributedBatchSampler(data.sampler.BatchSampler): |
| 76 | """ |
| 77 | similar to normal implementation of distributed sampler, except implementation is at the |
| 78 | batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary |
| 79 | data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. |
| 80 | """ |
def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False):
    """Initialise the batch sampler and its per-process bookkeeping.

    Args:
        sampler: Index sampler to wrap. A ``wrap_around`` attribute is set
            on it (initialised to 0).
        batch_size: Forwarded to the parent ``BatchSampler`` initialiser.
        drop_last: Forwarded to the parent ``BatchSampler`` initialiser.
        rank: Rank stored on the instance. The default ``-1`` is rejected,
            so callers must pass an explicit rank.
        world_size: Stored on the instance.
        wrap_last: Stored on the instance.

    Raises:
        AssertionError: If ``rank`` is ``-1``.
    """
    super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
    if rank == -1:
        # Raise explicitly instead of `assert False`: assert statements are
        # stripped under `python -O`, which would silently turn this guard
        # into a fallback to the process-group rank.
        raise AssertionError('should not be here')
    self.rank = rank
    self.world_size = world_size
    # Wrap-around offsets start at zero on both the wrapped sampler and
    # this batch sampler.
    self.sampler.wrap_around = 0
    self.wrap_around = 0
    self.wrap_last = wrap_last
    # Number of leading full batches to skip on the next iteration.
    self.start_iter = 0
| 92 | |
def __iter__(self):
    """Yield batches of indices drawn from the wrapped sampler.

    Indices produced by ``self.data_iterator`` are collected into groups of
    ``self.batch_size``. Each full group is passed through ``self._batch``
    and the result is yielded, except that full batches whose position is
    below ``self.start_iter`` are skipped; ``self.start_iter`` is reset to 0
    once a batch has been yielded and iteration resumes.

    A trailing partial batch is yielded unless ``self.drop_last`` is set.
    When ``self.wrap_last`` is set, the wrap-around counters on this object
    and on the wrapped sampler are updated first, and for a
    ``TransposedSampler`` the partial batch is topped up from a second pass
    over the sampler.
    """
    # NOTE(review): nesting below was reconstructed from a copy of this file
    # that had lost its indentation -- confirm against the original layout.
    batch = []
    batch_count = 0
    for idx in self.data_iterator(self.sampler, wrap_around=False):
        batch.append(idx)
        if len(batch) == self.batch_size:
            tbatch = self._batch(batch)
            if batch_count >= self.start_iter:
                yield tbatch
                # Runs when the consumer asks for the next batch.
                self.start_iter = 0
            batch_count += 1
            batch = []

    if batch and not self.drop_last:
        if self.wrap_last:
            self.sampler.wrap_around -= self.batch_size
            self.wrap_around += len(batch)
            self.wrap_around %= self.batch_size
            if isinstance(self.sampler, TransposedSampler):
                # Top up the partial batch from a fresh pass, skipping the
                # first element that pass produces.
                for pos, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)):
                    if pos == 0:
                        continue
                    batch.append(idx)
                    if len(batch) == self.batch_size:
                        break
        yield self._batch(batch)

    if self.wrap_last:
        self.sampler.wrap_around += self.batch_size
| 124 | |
| 125 | def data_iterator(self, _iter, wrap_around=False): |
| 126 | """iterates through data and handles wrap around""" |
| 127 | for i, idx in enumerate(_iter): |
| 128 | if i < self.wrap_around%self.batch_size: |
| 129 | continue |
| 130 | if wrap_around: |
| 131 | self.wrap_around += 1 |
| 132 | self.wrap_around %= self.batch_size |